CUDNN Frontend API  8.2.0
cudnn_frontend_Operation.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice shall be included in
12  * all copies or substantial portions of the Software.
13  *
14  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20  * DEALINGS IN THE SOFTWARE.
21  */
22 
23 #pragma once
24 
25 #include <algorithm>
26 #include <array>
27 #include <functional>
28 #include <memory>
29 #include <sstream>
30 #include <utility>
31 
32 #include <cudnn.h>
33 #include <cudnn_backend.h>
34 
39 #include "cudnn_frontend_Tensor.h"
40 #include "cudnn_frontend_utils.h"
41 
42 namespace cudnn_frontend {
68  public:
69  friend class OperationBuilder_v8;
70  std::string
71  describe() const override {
72  std::stringstream ss;
73  ss << "CUDNN_BACKEND_OPERATION :"
74  << " OpMode: " << std::to_string(op_mode);
75  ss << std::hex << " X " << xdesc;
76  ss << std::hex << " Y " << ydesc;
77  ss << std::hex << " W " << wdesc;
78  ss << std::hex << " B " << bdesc;
79  ss << std::hex << " DW " << dwdesc;
80  ss << std::hex << " DY " << dydesc;
81  ss << std::hex << " DX " << dxdesc;
82  ss << std::hex << " C " << cdesc;
83  ss << std::hex << " A Mtrix " << amatdesc;
84  ss << std::hex << " B Mtrix " << bmatdesc;
85  ss << std::hex << " C Mtrix " << cmatdesc;
86  ss << std::hex << " P " << pwdesc;
87  ss << std::hex << " MatMul " << matmuldesc;
88  ss << std::hex << " Reduction " << reductiondesc;
89  ss << std::dec << " alphabetaType " << alphabetaType;
90  ss << " Alpha: " << alpha_s << " " << alpha_d;
91  ss << " Alpha2: " << alpha2_s << " " << alpha2_d;
92  ss << " Beta: " << beta_s << " " << beta_d;
93  return ss.str();
94  }
96  : BackendDescriptor(from.pointer, from.get_status(), from.get_error()),
97  op_mode(from.op_mode),
98  xdesc(from.xdesc),
99  ydesc(from.ydesc),
100  wdesc(from.wdesc),
101  bdesc(from.bdesc),
102  dydesc(from.dydesc),
103  dxdesc(from.dxdesc),
104  dwdesc(from.dwdesc),
105  cdesc(from.cdesc),
106  amatdesc(from.amatdesc),
107  bmatdesc(from.bmatdesc),
108  cmatdesc(from.cmatdesc),
109  pwdesc(from.pwdesc),
110  matmuldesc(from.matmuldesc),
113  alpha_s(from.alpha_s),
114  beta_s(from.beta_s),
115  alpha2_s(from.alpha2_s),
116  alpha_d(from.alpha_d),
117  beta_d(from.beta_d),
118  alpha2_d(from.alpha2_d),
121  operationTag(from.operationTag) {}
122 
125  return (op_mode == CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) ? cmatdesc : ydesc;
126  }
127 
128  std::string const &
129  getTag() const {
130  return operationTag;
131  }
132 
133  ~Operation_v8() = default;
134 
135  private:
136  Operation_v8() = default;
137  Operation_v8(Operation_v8 const &) = delete;
138  Operation_v8 &
139  operator=(Operation_v8 const &) = delete;
140 
141  cudnnBackendDescriptorType_t op_mode = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
142 
157 
158  cudnnBackendAttributeType_t alphabetaType = CUDNN_TYPE_FLOAT;
159  float alpha_s = 1.0f, beta_s = .0f, alpha2_s = 1.0f;
160  double alpha_d = 1.0, beta_d = 0.0, alpha2_d = 1.0;
161  int64_t pointwise_port_count = -1;
162  cudnnPointwiseMode_t pointwise_mode;
165  bool is_pointwise_math_op = false;
166  std::string operationTag;
167 };
168 
172 
174  private:
176  bool is_convolution_op = false;
177  bool is_pointwise_op = false;
178  bool is_matmul_op = false;
179  bool is_reduction_op = false;
180 
181  public:
186  auto
188  m_operation.xdesc = raw_tensor;
189  return *this;
190  }
191 
192  auto
193  setxDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 & {
194  m_operation.xdesc = tensor.get_desc();
195  return *this;
196  }
197  auto
198  setbDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 & {
199  if (is_pointwise_op == false) {
201  &m_operation,
202  CUDNN_STATUS_BAD_PARAM,
203  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Pointwise operation does not need bTensor");
204  }
205  m_operation.bdesc = tensor.get_desc();
206  return *this;
207  }
208  auto
209  setyDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 & {
210  m_operation.ydesc = tensor.get_desc();
211  return *this;
212  }
213  auto
214  setwDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 & {
215  if (is_convolution_op == false) {
217  &m_operation,
218  CUDNN_STATUS_BAD_PARAM,
219  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Convolution operation does not need wTensor");
220  }
221  m_operation.wdesc = tensor.get_desc();
222  return *this;
223  }
224 
225  auto
227  m_operation.dydesc = raw_tensor;
228  return *this;
229  }
230  auto
232  m_operation.dydesc = tensor.get_desc();
233  return *this;
234  }
235  auto
237  m_operation.dxdesc = tensor.get_desc();
238  return *this;
239  }
240  auto
242  m_operation.dwdesc = tensor.get_desc();
243  return *this;
244  }
245 
246  auto
248  if (is_convolution_op == false) {
250  &m_operation,
251  CUDNN_STATUS_BAD_PARAM,
252  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Convolution operation does not need Convolution DESCRIPTOR");
253  }
254  m_operation.cdesc = conv.get_desc();
255  return *this;
256  }
257  auto
259  if (is_matmul_op == false) {
261  &m_operation,
262  CUDNN_STATUS_BAD_PARAM,
263  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Matmul operation does not need a Matrix Tensor");
264  }
265  m_operation.amatdesc = tensor.get_desc();
266  return *this;
267  }
268  auto
270  if (is_matmul_op == false) {
272  &m_operation,
273  CUDNN_STATUS_BAD_PARAM,
274  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Matmul operation does not need b Matrix Tensor");
275  }
276  m_operation.bmatdesc = tensor.get_desc();
277  return *this;
278  }
279  auto
281  if (is_matmul_op == false) {
283  &m_operation,
284  CUDNN_STATUS_BAD_PARAM,
285  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Matmul operation does not need c Matrix Tensor");
286  }
287  m_operation.cmatdesc = tensor.get_desc();
288  return *this;
289  }
290  auto
292  if (is_matmul_op == false) {
294  &m_operation,
295  CUDNN_STATUS_BAD_PARAM,
296  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Matmul operation does not need MATMUL DESCRIPTOR");
297  }
298  m_operation.matmuldesc = matmulDesc.get_desc();
299  return *this;
300  }
301  auto
303  if (is_reduction_op == false) {
305  &m_operation,
306  CUDNN_STATUS_BAD_PARAM,
307  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Reduction operation does not need REDUCTION DESCRIPTOR");
308  }
309  m_operation.reductiondesc = reductionDesc.get_desc();
310  return *this;
311  }
312  auto
313  setpwDesc(PointWiseDesc_v8 const &pointWiseDesc) -> OperationBuilder_v8 & {
314  if (is_pointwise_op == false) {
316  &m_operation,
317  CUDNN_STATUS_BAD_PARAM,
318  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Pointwise operation does not need POINTWISE DESCRIPTOR");
319  }
320  m_operation.pwdesc = pointWiseDesc.get_desc();
321  m_operation.pointwise_port_count = pointWiseDesc.getPortCount();
322  m_operation.pointwise_mode = pointWiseDesc.getPointWiseMode();
323 
324  m_operation.is_pointwise_math_op = ((m_operation.pointwise_mode == CUDNN_POINTWISE_ADD) ||
325  (m_operation.pointwise_mode == CUDNN_POINTWISE_MUL) ||
326  (m_operation.pointwise_mode == CUDNN_POINTWISE_MIN) ||
327  (m_operation.pointwise_mode == CUDNN_POINTWISE_MAX) ||
328  (m_operation.pointwise_mode == CUDNN_POINTWISE_SQRT));
329 
330  m_operation.is_pointwise_activation_fwd_op = ((m_operation.pointwise_mode == CUDNN_POINTWISE_RELU_FWD) ||
331  (m_operation.pointwise_mode == CUDNN_POINTWISE_TANH_FWD) ||
332  (m_operation.pointwise_mode == CUDNN_POINTWISE_SIGMOID_FWD) ||
333  (m_operation.pointwise_mode == CUDNN_POINTWISE_ELU_FWD) ||
334  (m_operation.pointwise_mode == CUDNN_POINTWISE_GELU_FWD) ||
335  (m_operation.pointwise_mode == CUDNN_POINTWISE_SOFTPLUS_FWD) ||
336  (m_operation.pointwise_mode == CUDNN_POINTWISE_SWISH_FWD));
337 
338  m_operation.is_pointwise_activation_bwd_op = ((m_operation.pointwise_mode == CUDNN_POINTWISE_RELU_BWD) ||
339  (m_operation.pointwise_mode == CUDNN_POINTWISE_TANH_BWD) ||
340  (m_operation.pointwise_mode == CUDNN_POINTWISE_SIGMOID_BWD) ||
341  (m_operation.pointwise_mode == CUDNN_POINTWISE_ELU_BWD) ||
342  (m_operation.pointwise_mode == CUDNN_POINTWISE_GELU_BWD) ||
343  (m_operation.pointwise_mode == CUDNN_POINTWISE_SOFTPLUS_BWD) ||
344  (m_operation.pointwise_mode == CUDNN_POINTWISE_SWISH_BWD));
345 
346  return *this;
347  }
348 
349  auto
350  setAlpha(float alpha) -> OperationBuilder_v8 & {
351  m_operation.alphabetaType = CUDNN_TYPE_FLOAT;
352  m_operation.alpha_d = static_cast<double>(alpha);
353  m_operation.alpha_s = alpha;
354  return *this;
355  }
356  auto
357  setAlpha(double alpha) -> OperationBuilder_v8 & {
358  m_operation.alphabetaType = CUDNN_TYPE_DOUBLE;
359  m_operation.alpha_s = static_cast<float>(alpha);
360  m_operation.alpha_d = alpha;
361  return *this;
362  }
363  auto
364  setAlpha2(float alpha) -> OperationBuilder_v8 & {
365  m_operation.alphabetaType = CUDNN_TYPE_FLOAT;
366  m_operation.alpha2_d = static_cast<double>(alpha);
367  m_operation.alpha2_s = alpha;
368  return *this;
369  }
370  auto
371  setAlpha2(double alpha) -> OperationBuilder_v8 & {
372  m_operation.alphabetaType = CUDNN_TYPE_DOUBLE;
373  m_operation.alpha2_s = static_cast<float>(alpha);
374  m_operation.alpha2_d = alpha;
375  return *this;
376  }
377  auto
378  setBeta(float beta) -> OperationBuilder_v8 & {
379  m_operation.alphabetaType = CUDNN_TYPE_FLOAT;
380  m_operation.beta_d = static_cast<double>(beta);
381  m_operation.beta_s = beta;
382  return *this;
383  }
384  auto
385  setBeta(double beta) -> OperationBuilder_v8 & {
386  m_operation.alphabetaType = CUDNN_TYPE_DOUBLE;
387  m_operation.beta_s = static_cast<float>(beta);
388  m_operation.beta_d = beta;
389  return *this;
390  }
391 
392  OperationBuilder_v8(cudnnBackendDescriptorType_t mode) {
393  m_operation.op_mode = mode;
394  is_convolution_op = ((m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) ||
395  (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR) ||
396  (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR));
397 
398  is_pointwise_op = (m_operation.op_mode == CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR);
399  is_matmul_op = (m_operation.op_mode == CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR);
400  is_reduction_op = (m_operation.op_mode == CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR);
401  }
404  Operation_v8 &&
407  build() {
408  if (m_operation.status != CUDNN_STATUS_SUCCESS) {
410  &m_operation, m_operation.status, "CUDNN_BACKEND_OPERATION: Operation not initialized properly");
411  return std::move(m_operation);
412  }
413 
414  if (is_convolution_op) {
415  if (m_operation.cdesc == nullptr) {
417  &m_operation,
418  CUDNN_STATUS_BAD_PARAM,
419  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_CONV_DESC");
420  return std::move(m_operation);
421  }
422  if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) {
423  if (m_operation.xdesc == nullptr) {
425  &m_operation,
426  CUDNN_STATUS_BAD_PARAM,
427  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_X");
428  return std::move(m_operation);
429  }
430  if (m_operation.wdesc == nullptr) {
432  &m_operation,
433  CUDNN_STATUS_BAD_PARAM,
434  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_W");
435  return std::move(m_operation);
436  }
437  if (m_operation.ydesc == nullptr) {
439  &m_operation,
440  CUDNN_STATUS_BAD_PARAM,
441  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_Y");
442  return std::move(m_operation);
443  }
444 
445  } else if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR) {
446  if (m_operation.ydesc != nullptr && m_operation.dydesc != nullptr) {
447  set_error_and_throw_exception(&m_operation,
448  CUDNN_STATUS_BAD_PARAM,
449  "CUDNN_BACKEND_OPERATION: Ambiguous specification. Choose and Set "
450  "only one of setyDesc() or setdyDesc()");
451  return std::move(m_operation);
452  }
453  if (m_operation.ydesc == nullptr && m_operation.dydesc == nullptr) {
455  &m_operation,
456  CUDNN_STATUS_BAD_PARAM,
457  "CUDNN_BACKEND_OPERATION: Choose and Set one of setyDesc() or setdyDesc()");
458  return std::move(m_operation);
459  }
460  if (m_operation.xdesc == nullptr) {
462  &m_operation,
463  CUDNN_STATUS_BAD_PARAM,
464  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_X");
465  return std::move(m_operation);
466  }
467  if (m_operation.wdesc != nullptr && m_operation.dwdesc != nullptr) {
468  set_error_and_throw_exception(&m_operation,
469  CUDNN_STATUS_BAD_PARAM,
470  "CUDNN_BACKEND_OPERATION: Ambiguous specification. Choose and Set "
471  "only one of setwDesc() or setdwDesc()");
472  return std::move(m_operation);
473  }
474  if (m_operation.wdesc == nullptr && m_operation.dwdesc == nullptr) {
476  &m_operation,
477  CUDNN_STATUS_BAD_PARAM,
478  "CUDNN_BACKEND_OPERATION: Choose and Set one of setwDesc() or setdwDesc()");
479  return std::move(m_operation);
480  }
481  } else if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {
482  if (m_operation.ydesc != nullptr && m_operation.dydesc != nullptr) {
483  set_error_and_throw_exception(&m_operation,
484  CUDNN_STATUS_BAD_PARAM,
485  "CUDNN_BACKEND_OPERATION: Ambiguous specification. Choose and Set "
486  "only one of setyDesc() or setdyDesc()");
487  return std::move(m_operation);
488  }
489  if (m_operation.ydesc == nullptr && m_operation.dydesc == nullptr) {
491  &m_operation,
492  CUDNN_STATUS_BAD_PARAM,
493  "CUDNN_BACKEND_OPERATION: Choose and Set one of setyDesc() or setdyDesc()");
494  return std::move(m_operation);
495  }
496  if (m_operation.wdesc == nullptr) {
498  &m_operation,
499  CUDNN_STATUS_BAD_PARAM,
500  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_W");
501  return std::move(m_operation);
502  }
503  if (m_operation.xdesc != nullptr && m_operation.dxdesc != nullptr) {
504  set_error_and_throw_exception(&m_operation,
505  CUDNN_STATUS_BAD_PARAM,
506  "CUDNN_BACKEND_OPERATION: Ambiguous specification. Choose and Set "
507  "only one of setxDesc() or setdxDesc()");
508  return std::move(m_operation);
509  }
510  if (m_operation.xdesc == nullptr && m_operation.dxdesc == nullptr) {
512  &m_operation,
513  CUDNN_STATUS_BAD_PARAM,
514  "CUDNN_BACKEND_OPERATION: Choose and Set one of setxDesc() or setdxDesc()");
515  return std::move(m_operation);
516  }
517  } else {
518  set_error_and_throw_exception(&m_operation,
519  CUDNN_STATUS_BAD_PARAM,
520  "CUDNN_BACKEND_OPERATION: Unsupported convolution operation. Check and "
521  "set CUDNN_BACKEND_OPERATION_CONVOLUTION_*_DESCRIPTOR");
522  return std::move(m_operation);
523  }
524  } else if (is_pointwise_op) {
525  if (m_operation.xdesc == nullptr) {
527  &m_operation,
528  CUDNN_STATUS_BAD_PARAM,
529  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_POINTWISE_XDESC");
530  return std::move(m_operation);
531  }
532 
533  if (m_operation.is_pointwise_math_op) {
534  if (m_operation.pointwise_port_count == 3 && m_operation.bdesc == nullptr) {
536  &m_operation,
537  CUDNN_STATUS_BAD_PARAM,
538  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_POINTWISE_BDESC");
539  return std::move(m_operation);
540  }
541  if (m_operation.ydesc == nullptr) {
543  &m_operation,
544  CUDNN_STATUS_BAD_PARAM,
545  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_POINTWISE_YDESC");
546  return std::move(m_operation);
547  }
548  } else if (m_operation.is_pointwise_activation_fwd_op) {
549  if (m_operation.ydesc == nullptr) {
551  &m_operation,
552  CUDNN_STATUS_BAD_PARAM,
553  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_POINTWISE_YDESC");
554  return std::move(m_operation);
555  }
556  } else if (m_operation.is_pointwise_activation_bwd_op) {
557  if (m_operation.dydesc == nullptr) {
559  &m_operation,
560  CUDNN_STATUS_BAD_PARAM,
561  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_POINTWISE_DYDESC");
562  return std::move(m_operation);
563  }
564  if (m_operation.dxdesc == nullptr) {
566  &m_operation,
567  CUDNN_STATUS_BAD_PARAM,
568  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_POINTWISE_DXDESC");
569  return std::move(m_operation);
570  }
571  } else {
573  &m_operation,
574  CUDNN_STATUS_BAD_PARAM,
575  "CUDNN_BACKEND_OPERATION: Unsupported cudnn pointwise mode. Check and set CUDNN_POINTWISE_*");
576  return std::move(m_operation);
577  }
578 
579  } else if (is_matmul_op) {
580  if (m_operation.matmuldesc == nullptr) {
582  &m_operation,
583  CUDNN_STATUS_BAD_PARAM,
584  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_MATMUL_DESC");
585  return std::move(m_operation);
586  }
587  if (m_operation.amatdesc == nullptr) {
589  &m_operation,
590  CUDNN_STATUS_BAD_PARAM,
591  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_MATMUL_ADESC");
592  return std::move(m_operation);
593  }
594  if (m_operation.bmatdesc == nullptr) {
596  &m_operation,
597  CUDNN_STATUS_BAD_PARAM,
598  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_MATMUL_BDESC");
599  return std::move(m_operation);
600  }
601  if (m_operation.cmatdesc == nullptr) {
603  &m_operation,
604  CUDNN_STATUS_BAD_PARAM,
605  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_MATMUL_CDESC");
606  return std::move(m_operation);
607  }
608  } else if (is_reduction_op) {
609  if (m_operation.reductiondesc == nullptr) {
611  &m_operation,
612  CUDNN_STATUS_BAD_PARAM,
613  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_REDUCTION_DESC");
614  return std::move(m_operation);
615  }
616  if (m_operation.xdesc == nullptr) {
618  &m_operation,
619  CUDNN_STATUS_BAD_PARAM,
620  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_REDUCTION_XDESC");
621  return std::move(m_operation);
622  }
623  if (m_operation.ydesc == nullptr) {
625  &m_operation,
626  CUDNN_STATUS_BAD_PARAM,
627  "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_REDUCTION_YDESC");
628  return std::move(m_operation);
629  }
630  } else {
631  set_error_and_throw_exception(&m_operation,
632  CUDNN_STATUS_BAD_PARAM,
633  "CUDNN_BACKEND_OPERATION_DESCRIPTOR: Unsupported cudnn backend descriptor "
634  "type. Check and set CUDNN_BACKEND_OPERATION_*_DESCRIPTOR");
635  return std::move(m_operation);
636  }
637 
638  // Create the descriptor.
639  auto status = m_operation.initialize_managed_backend_pointer(m_operation.op_mode);
640  if (status != CUDNN_STATUS_SUCCESS) {
641  set_error_and_throw_exception(&m_operation, status, "CUDNN_BACKEND_OPERATION: cudnnCreate Failed");
642  return std::move(m_operation);
643  }
644 
645  if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) {
646  m_operation.operationTag = "ConvFwd";
647 
648  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
649  CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X,
650  CUDNN_TYPE_BACKEND_DESCRIPTOR,
651  1,
652  &(m_operation.xdesc->get_backend_descriptor()));
653  if (status != CUDNN_STATUS_SUCCESS) {
655  &m_operation,
656  status,
657  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X Failed");
658  return std::move(m_operation);
659  }
660  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
661  CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W,
662  CUDNN_TYPE_BACKEND_DESCRIPTOR,
663  1,
664  &(m_operation.wdesc->get_backend_descriptor()));
665  if (status != CUDNN_STATUS_SUCCESS) {
667  &m_operation,
668  status,
669  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W Failed");
670  return std::move(m_operation);
671  }
672  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
673  CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y,
674  CUDNN_TYPE_BACKEND_DESCRIPTOR,
675  1,
676  &(m_operation.ydesc->get_backend_descriptor()));
677  if (status != CUDNN_STATUS_SUCCESS) {
679  &m_operation,
680  status,
681  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y Failed");
682  return std::move(m_operation);
683  }
684  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
685  CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC,
686  CUDNN_TYPE_BACKEND_DESCRIPTOR,
687  1,
688  &(m_operation.cdesc->get_backend_descriptor()));
689  if (status != CUDNN_STATUS_SUCCESS) {
691  &m_operation,
692  status,
693  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC Failed");
694  return std::move(m_operation);
695  }
696  void *alpha = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.alpha_s)
697  : static_cast<void *>(&m_operation.alpha_d));
698  void *beta = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.beta_s)
699  : static_cast<void *>(&m_operation.beta_d));
700  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
701  CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA,
702  m_operation.alphabetaType,
703  1,
704  alpha);
705  if (status != CUDNN_STATUS_SUCCESS) {
707  &m_operation,
708  status,
709  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA Failed");
710  return std::move(m_operation);
711  }
712  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
713  CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA,
714  m_operation.alphabetaType,
715  1,
716  beta);
717  if (status != CUDNN_STATUS_SUCCESS) {
719  &m_operation,
720  status,
721  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA Failed");
722  return std::move(m_operation);
723  }
724  } else if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR) {
725  m_operation.operationTag = "ConvBwdFilter";
726 
727  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
728  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X,
729  CUDNN_TYPE_BACKEND_DESCRIPTOR,
730  1,
731  &(m_operation.xdesc->get_backend_descriptor()));
732  if (status != CUDNN_STATUS_SUCCESS) {
734  &m_operation,
735  status,
736  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X Failed");
737  return std::move(m_operation);
738  }
739 
740  auto dwdesc_ = m_operation.dwdesc != nullptr ? m_operation.dwdesc : m_operation.wdesc;
741  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
742  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW,
743  CUDNN_TYPE_BACKEND_DESCRIPTOR,
744  1,
745  &(dwdesc_->get_backend_descriptor()));
746  if (status != CUDNN_STATUS_SUCCESS) {
748  &m_operation,
749  status,
750  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW Failed");
751  return std::move(m_operation);
752  }
753 
754  auto dydesc_ = m_operation.dydesc != nullptr ? m_operation.dydesc : m_operation.ydesc;
755  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
756  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY,
757  CUDNN_TYPE_BACKEND_DESCRIPTOR,
758  1,
759  &(dydesc_->get_backend_descriptor()));
760  if (status != CUDNN_STATUS_SUCCESS) {
762  &m_operation,
763  status,
764  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY Failed");
765  return std::move(m_operation);
766  }
767 
768  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
769  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC,
770  CUDNN_TYPE_BACKEND_DESCRIPTOR,
771  1,
772  &(m_operation.cdesc->get_backend_descriptor()));
773  if (status != CUDNN_STATUS_SUCCESS) {
774  set_error_and_throw_exception(&m_operation,
775  status,
776  "CUDNN_BACKEND_OPERATION: SetAttribute "
777  "CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC Failed");
778  return std::move(m_operation);
779  }
780  void *alpha = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.alpha_s)
781  : static_cast<void *>(&m_operation.alpha_d));
782  void *beta = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.beta_s)
783  : static_cast<void *>(&m_operation.beta_d));
784  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
785  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA,
786  m_operation.alphabetaType,
787  1,
788  alpha);
789  if (status != CUDNN_STATUS_SUCCESS) {
791  &m_operation,
792  status,
793  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA Failed");
794  return std::move(m_operation);
795  }
796  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
797  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA,
798  m_operation.alphabetaType,
799  1,
800  beta);
801  if (status != CUDNN_STATUS_SUCCESS) {
803  &m_operation,
804  status,
805  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA Failed");
806  return std::move(m_operation);
807  }
808  } else if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {
809  m_operation.operationTag = "ConvBwdData";
810 
811  auto dxdesc_ = m_operation.dxdesc != nullptr ? m_operation.dxdesc : m_operation.xdesc;
812  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
813  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX,
814  CUDNN_TYPE_BACKEND_DESCRIPTOR,
815  1,
816  &(dxdesc_->get_backend_descriptor()));
817  if (status != CUDNN_STATUS_SUCCESS) {
819  &m_operation,
820  status,
821  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX Failed");
822  return std::move(m_operation);
823  }
824 
825  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
826  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W,
827  CUDNN_TYPE_BACKEND_DESCRIPTOR,
828  1,
829  &(m_operation.wdesc->get_backend_descriptor()));
830  if (status != CUDNN_STATUS_SUCCESS) {
832  &m_operation,
833  status,
834  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W Failed");
835  return std::move(m_operation);
836  }
837 
838  auto dydesc_ = m_operation.dydesc != nullptr ? m_operation.dydesc : m_operation.ydesc;
839  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
840  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY,
841  CUDNN_TYPE_BACKEND_DESCRIPTOR,
842  1,
843  &(dydesc_->get_backend_descriptor()));
844  if (status != CUDNN_STATUS_SUCCESS) {
846  &m_operation,
847  status,
848  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY Failed");
849  return std::move(m_operation);
850  }
851 
852  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
853  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC,
854  CUDNN_TYPE_BACKEND_DESCRIPTOR,
855  1,
856  &(m_operation.cdesc->get_backend_descriptor()));
857  if (status != CUDNN_STATUS_SUCCESS) {
859  &m_operation,
860  status,
861  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC Failed");
862  return std::move(m_operation);
863  }
864 
865  void *alpha = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.alpha_s)
866  : static_cast<void *>(&m_operation.alpha_d));
867  void *beta = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.beta_s)
868  : static_cast<void *>(&m_operation.beta_d));
869  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
870  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA,
871  m_operation.alphabetaType,
872  1,
873  alpha);
874  if (status != CUDNN_STATUS_SUCCESS) {
876  &m_operation,
877  status,
878  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA Failed");
879  return std::move(m_operation);
880  }
881  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
882  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA,
883  m_operation.alphabetaType,
884  1,
885  beta);
886  if (status != CUDNN_STATUS_SUCCESS) {
888  &m_operation,
889  status,
890  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA Failed");
891  return std::move(m_operation);
892  }
893  } else if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) {
894  switch (m_operation.pointwise_mode) {
895  case CUDNN_POINTWISE_ADD:
896  m_operation.operationTag = "Add";
897  break;
898  case CUDNN_POINTWISE_MUL:
899  m_operation.operationTag = "Mul";
900  break;
901  case CUDNN_POINTWISE_MIN:
902  m_operation.operationTag = "Min";
903  break;
904  case CUDNN_POINTWISE_MAX:
905  m_operation.operationTag = "Max";
906  break;
907  case CUDNN_POINTWISE_SQRT:
908  m_operation.operationTag = "Sqrt";
909  break;
910  case CUDNN_POINTWISE_RELU_FWD:
911  m_operation.operationTag = "ReluFwd";
912  break;
913  case CUDNN_POINTWISE_TANH_FWD:
914  m_operation.operationTag = "TanhFwd";
915  break;
916  case CUDNN_POINTWISE_SIGMOID_FWD:
917  m_operation.operationTag = "SigmoidFwd";
918  break;
919  case CUDNN_POINTWISE_ELU_FWD:
920  m_operation.operationTag = "EluFwd";
921  break;
922  case CUDNN_POINTWISE_GELU_FWD:
923  m_operation.operationTag = "GeluFwd";
924  break;
925  case CUDNN_POINTWISE_SOFTPLUS_FWD:
926  m_operation.operationTag = "SoftplusFwd";
927  break;
928  case CUDNN_POINTWISE_SWISH_FWD:
929  m_operation.operationTag = "SwishFwd";
930  break;
931  case CUDNN_POINTWISE_RELU_BWD:
932  m_operation.operationTag = "ReluBwd";
933  break;
934  case CUDNN_POINTWISE_TANH_BWD:
935  m_operation.operationTag = "TanhBwd";
936  break;
937  case CUDNN_POINTWISE_SIGMOID_BWD:
938  m_operation.operationTag = "SigmoidBwd";
939  break;
940  case CUDNN_POINTWISE_ELU_BWD:
941  m_operation.operationTag = "EluBwd";
942  break;
943  case CUDNN_POINTWISE_GELU_BWD:
944  m_operation.operationTag = "GeluBwd";
945  break;
946  case CUDNN_POINTWISE_SOFTPLUS_BWD:
947  m_operation.operationTag = "SoftplusBwd";
948  break;
949  case CUDNN_POINTWISE_SWISH_BWD:
950  m_operation.operationTag = "SwishBwd";
951  break;
952  default:
953  m_operation.operationTag = "OtherOp";
954  break;
955  }
956 
957  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
958  CUDNN_ATTR_OPERATION_POINTWISE_PW_DESCRIPTOR,
959  CUDNN_TYPE_BACKEND_DESCRIPTOR,
960  1,
961  &(m_operation.pwdesc->get_backend_descriptor()));
962  if (status != CUDNN_STATUS_SUCCESS) {
964  &m_operation,
965  status,
966  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_PW_DESCRIPTOR Failed");
967  return std::move(m_operation);
968  }
969 
970  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
971  CUDNN_ATTR_OPERATION_POINTWISE_XDESC,
972  CUDNN_TYPE_BACKEND_DESCRIPTOR,
973  1,
974  &(m_operation.xdesc->get_backend_descriptor()));
975  if (status != CUDNN_STATUS_SUCCESS) {
977  &m_operation,
978  status,
979  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_XDESC Failed");
980  return std::move(m_operation);
981  }
982 
983  if (!m_operation.is_pointwise_activation_bwd_op) {
984  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
985  CUDNN_ATTR_OPERATION_POINTWISE_YDESC,
986  CUDNN_TYPE_BACKEND_DESCRIPTOR,
987  1,
988  &(m_operation.ydesc->get_backend_descriptor()));
989  if (status != CUDNN_STATUS_SUCCESS) {
991  &m_operation,
992  status,
993  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_YDESC Failed");
994  return std::move(m_operation);
995  }
996  } else {
997  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
998  CUDNN_ATTR_OPERATION_POINTWISE_DYDESC,
999  CUDNN_TYPE_BACKEND_DESCRIPTOR,
1000  1,
1001  &(m_operation.dydesc->get_backend_descriptor()));
1002  if (status != CUDNN_STATUS_SUCCESS) {
1004  &m_operation,
1005  status,
1006  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_DYDESC Failed");
1007  return std::move(m_operation);
1008  }
1009 
1010  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
1011  CUDNN_ATTR_OPERATION_POINTWISE_DXDESC,
1012  CUDNN_TYPE_BACKEND_DESCRIPTOR,
1013  1,
1014  &(m_operation.dxdesc->get_backend_descriptor()));
1015  if (status != CUDNN_STATUS_SUCCESS) {
1017  &m_operation,
1018  status,
1019  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_DXDESC Failed");
1020  return std::move(m_operation);
1021  }
1022  }
1023 
1024  void *alpha = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.alpha_s)
1025  : static_cast<void *>(&m_operation.alpha_d));
1026  void *alpha2 = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.alpha2_s)
1027  : static_cast<void *>(&m_operation.alpha2_d));
1028  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
1029  CUDNN_ATTR_OPERATION_POINTWISE_ALPHA1,
1030  m_operation.alphabetaType,
1031  1,
1032  alpha);
1033  if (status != CUDNN_STATUS_SUCCESS) {
1035  &m_operation,
1036  status,
1037  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_ALPHA1 Failed");
1038  return std::move(m_operation);
1039  }
1040  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
1041  CUDNN_ATTR_OPERATION_POINTWISE_ALPHA2,
1042  m_operation.alphabetaType,
1043  1,
1044  alpha2);
1045  if (status != CUDNN_STATUS_SUCCESS) {
1047  &m_operation,
1048  status,
1049  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_ALPHA2 Failed");
1050  return std::move(m_operation);
1051  }
1052 
1053  if (m_operation.pointwise_port_count == 3 && !m_operation.is_pointwise_activation_bwd_op) {
1054  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
1055  CUDNN_ATTR_OPERATION_POINTWISE_BDESC,
1056  CUDNN_TYPE_BACKEND_DESCRIPTOR,
1057  1,
1058  &(m_operation.bdesc->get_backend_descriptor()));
1059  if (status != CUDNN_STATUS_SUCCESS) {
1061  &m_operation,
1062  status,
1063  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_BDESC Failed");
1064  return std::move(m_operation);
1065  }
1066  }
1067  } else if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) {
1068  m_operation.operationTag = "Matmul";
1069  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
1070  CUDNN_ATTR_OPERATION_MATMUL_ADESC,
1071  CUDNN_TYPE_BACKEND_DESCRIPTOR,
1072  1,
1073  &(m_operation.amatdesc->get_backend_descriptor()));
1074  if (status != CUDNN_STATUS_SUCCESS) {
1076  &m_operation,
1077  status,
1078  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_MATMUL_ADESC Failed");
1079  return std::move(m_operation);
1080  }
1081  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
1082  CUDNN_ATTR_OPERATION_MATMUL_BDESC,
1083  CUDNN_TYPE_BACKEND_DESCRIPTOR,
1084  1,
1085  &(m_operation.bmatdesc->get_backend_descriptor()));
1086  if (status != CUDNN_STATUS_SUCCESS) {
1088  &m_operation,
1089  status,
1090  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_MATMUL_BDESC Failed");
1091  return std::move(m_operation);
1092  }
1093  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
1094  CUDNN_ATTR_OPERATION_MATMUL_CDESC,
1095  CUDNN_TYPE_BACKEND_DESCRIPTOR,
1096  1,
1097  &(m_operation.cmatdesc->get_backend_descriptor()));
1098  if (status != CUDNN_STATUS_SUCCESS) {
1100  &m_operation,
1101  status,
1102  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_MATMUL_CDESC Failed");
1103  return std::move(m_operation);
1104  }
1105  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
1106  CUDNN_ATTR_OPERATION_MATMUL_DESC,
1107  CUDNN_TYPE_BACKEND_DESCRIPTOR,
1108  1,
1109  &(m_operation.matmuldesc->get_backend_descriptor()));
1110  if (status != CUDNN_STATUS_SUCCESS) {
1112  &m_operation,
1113  status,
1114  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_MATMUL_DESC Failed");
1115  return std::move(m_operation);
1116  }
1117  } else if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) {
1118  m_operation.operationTag = "Reduction";
1119  if ((cudnnGetVersion() / 100) == 81) { // workaround for cudnn 8.1
1120  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
1121  CUDNN_ATTR_REDUCTION_OPERATOR,
1122  CUDNN_TYPE_BACKEND_DESCRIPTOR,
1123  1,
1124  &(m_operation.reductiondesc->get_backend_descriptor()));
1125  } else {
1126  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
1127  CUDNN_ATTR_OPERATION_REDUCTION_DESC,
1128  CUDNN_TYPE_BACKEND_DESCRIPTOR,
1129  1,
1130  &(m_operation.reductiondesc->get_backend_descriptor()));
1131  }
1132  if (status != CUDNN_STATUS_SUCCESS) {
1134  &m_operation,
1135  status,
1136  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_REDUCTION_DESC Failed");
1137  return std::move(m_operation);
1138  }
1139  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
1140  CUDNN_ATTR_OPERATION_REDUCTION_XDESC,
1141  CUDNN_TYPE_BACKEND_DESCRIPTOR,
1142  1,
1143  &(m_operation.xdesc->get_backend_descriptor()));
1144  if (status != CUDNN_STATUS_SUCCESS) {
1146  &m_operation,
1147  status,
1148  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_REDUCTION_XDESC Failed");
1149  return std::move(m_operation);
1150  }
1151  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
1152  CUDNN_ATTR_OPERATION_REDUCTION_YDESC,
1153  CUDNN_TYPE_BACKEND_DESCRIPTOR,
1154  1,
1155  &(m_operation.ydesc->get_backend_descriptor()));
1156  if (status != CUDNN_STATUS_SUCCESS) {
1158  &m_operation,
1159  status,
1160  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_REDUCTION_YDESC Failed");
1161  return std::move(m_operation);
1162  }
1163  }
1164  status = cudnnBackendFinalize(m_operation.pointer->get_backend_descriptor());
1165  if (status != CUDNN_STATUS_SUCCESS) {
1166  set_error_and_throw_exception(&m_operation, status, "CUDNN_BACKEND_OPERATION: cudnnFinalize Failed");
1167  return std::move(m_operation);
1168  }
1169  return std::move(m_operation);
1170  }
1171 };
1172 }
auto setcDesc(ConvDesc_v8 const &conv) -> OperationBuilder_v8 &
cudnnStatus_t initialize_managed_backend_pointer(cudnnBackendDescriptorType_t type)
Initializes the underlying managed descriptor.
static void set_error_and_throw_exception(BackendDescriptor const *desc, cudnnStatus_t status, const char *message)
NLOHMANN_BASIC_JSON_TPL_DECLARATION std::string to_string(const NLOHMANN_BASIC_JSON_TPL &j)
user-defined to_string function for JSON values
Definition: json.hpp:25855
auto setAlpha(float alpha) -> OperationBuilder_v8 &
auto setdxDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
auto setwDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
Operation_v8 & operator=(Operation_v8 const &)=delete
auto setbDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
auto setaMatDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
auto setmatmulDesc(MatMulDesc_v8 const &matmulDesc) -> OperationBuilder_v8 &
cudnnBackendDescriptorType_t op_mode
auto setBeta(float beta) -> OperationBuilder_v8 &
auto setpwDesc(PointWiseDesc_v8 const &pointWiseDesc) -> OperationBuilder_v8 &
auto setAlpha2(float alpha) -> OperationBuilder_v8 &
auto setdwDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
auto setreductionDesc(ReductionDesc_v8 const &reductionDesc) -> OperationBuilder_v8 &
cudnnStatus_t get_status() const
Current status of the descriptor.
auto setBeta(double beta) -> OperationBuilder_v8 &
auto setbMatDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
ManagedOpaqueDescriptor reductiondesc
std::shared_ptr< OpaqueBackendPointer > ManagedOpaqueDescriptor
auto setdyDesc(ManagedOpaqueDescriptor const &raw_tensor) -> OperationBuilder_v8 &
std::string describe() const override
Return a string describing the backend Descriptor.
auto setdyDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
const char * get_error() const
Diagnostic error message, if any.
cudnnBackendAttributeType_t alphabetaType
auto setxDesc(ManagedOpaqueDescriptor const &raw_tensor) -> OperationBuilder_v8 &
auto setcMatDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
auto setyDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
auto setxDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
ManagedOpaqueDescriptor getOutputTensor()
auto setAlpha2(double alpha) -> OperationBuilder_v8 &
OperationBuilder_v8(cudnnBackendDescriptorType_t mode)
std::string const & getTag() const
auto setAlpha(double alpha) -> OperationBuilder_v8 &
cudnnStatus_t status
Shared pointer of the OpaqueBackendPointer.