CUDNN Frontend API  8.2.0
cudnn_frontend_OperationGraph.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice shall be included in
12  * all copies or substantial portions of the Software.
13  *
14  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20  * DEALINGS IN THE SOFTWARE.
21  */
22 
23 #pragma once
24 
25 #include <algorithm>
26 #include <array>
27 #include <functional>
28 #include <memory>
29 #include <sstream>
30 #include <utility>
31 
32 #include <cudnn.h>
33 #include <cudnn_backend.h>
34 
36 #include "cudnn_frontend_utils.h"
37 
38 namespace cudnn_frontend {
39 
52  public:
54  std::string
55  describe() const override {
56  std::stringstream ss;
57  ss << "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR :";
58  return ss.str();
59  }
60 
62  : BackendDescriptor(from.get_desc(), from.get_status(), from.get_error()),
63  handle(from.handle),
64  ops(from.ops),
65  numOps(from.numOps),
66  opGraphTag(from.opGraphTag) {}
67 
68  ~OperationGraph_v8() = default;
69 
74  auto
76  getEngineCount(void) const -> int64_t {
77  int64_t global_count = -1;
78  auto status = cudnnBackendGetAttribute(pointer->get_backend_descriptor(),
79  CUDNN_ATTR_OPERATIONGRAPH_ENGINE_GLOBAL_COUNT,
80  CUDNN_TYPE_INT64,
81  1,
82  NULL,
83  &global_count);
84  if (status != CUDNN_STATUS_SUCCESS) {
86  status,
87  "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: GetAttribute "
88  "CUDNN_ATTR_OPERATIONGRAPH_ENGINE_GLOBAL_COUNT Failed");
89  }
90  return global_count;
91  }
94  uint64_t
95  getOpCount() const {
96  return numOps;
97  }
98 
99  std::string const &
100  getTag() const {
101  return opGraphTag;
102  }
103 
104  private:
105  OperationGraph_v8() = default;
106  OperationGraph_v8(OperationGraph_v8 const &) = delete;
108  operator=(OperationGraph_v8 const &) = delete;
109 
110  cudnnHandle_t handle = nullptr;
111  std::array<ManagedOpaqueDescriptor, 10> ops{};
112  int64_t numOps = -1;
113  std::string opGraphTag = "";
114 };
115 
120  public:
125  auto
127  setHandle(cudnnHandle_t handle_) -> OperationGraphBuilder_v8 & {
128  m_operationGraph.handle = handle_;
129  return *this;
130  }
132  auto
133  setOperationGraph(int64_t numOps_, Operation_v8 const **ops_) -> OperationGraphBuilder_v8 & {
134  m_operationGraph.numOps = numOps_;
135  for (auto i = 0u; i < numOps_; i++) {
136  m_operationGraph.ops[i] = ops_[i]->get_desc();
137  m_operationGraph.opGraphTag += ops_[i]->getTag() + '_';
138  }
139  return *this;
140  }
146  build() {
147  if (m_operationGraph.numOps <= 0) {
149  &m_operationGraph,
150  CUDNN_STATUS_BAD_PARAM,
151  "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: Check and Set the CUDNN_ATTR_OPERATIONGRAPH_OPS Count field");
152  return std::move(m_operationGraph);
153  }
154  if (m_operationGraph.ops[0] == nullptr) {
156  &m_operationGraph,
157  CUDNN_STATUS_BAD_PARAM,
158  "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: Check and set CUDNN_ATTR_OPERATIONGRAPH_OPS field");
159  return std::move(m_operationGraph);
160  }
161  if (m_operationGraph.handle == nullptr) {
163  &m_operationGraph,
164  CUDNN_STATUS_BAD_PARAM,
165  "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: Check and Set CUDNN_ATTR_OPERATIONGRAPH_HANDLE");
166  return std::move(m_operationGraph);
167  }
168 
169  // Create a descriptor. Memory allocation happens here.
170  auto status = m_operationGraph.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR);
171  if (status != CUDNN_STATUS_SUCCESS) {
173  &m_operationGraph, status, "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: cudnnCreate Failed");
174  return std::move(m_operationGraph);
175  }
176 
177  std::array<cudnnBackendDescriptor_t, 10> ops_raw{nullptr};
178  for (auto i = 0u; i < m_operationGraph.numOps; i++) {
179  ops_raw[i] = m_operationGraph.ops[i]->get_backend_descriptor();
180  }
181 
182  status = cudnnBackendSetAttribute(m_operationGraph.pointer->get_backend_descriptor(),
183  CUDNN_ATTR_OPERATIONGRAPH_OPS,
184  CUDNN_TYPE_BACKEND_DESCRIPTOR,
185  m_operationGraph.numOps,
186  ops_raw.data());
187  if (status != CUDNN_STATUS_SUCCESS) {
189  &m_operationGraph,
190  status,
191  "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: SetAttribute CUDNN_ATTR_OPERATIONGRAPH_OPS Failed");
192  return std::move(m_operationGraph);
193  }
194  status = cudnnBackendSetAttribute(m_operationGraph.pointer->get_backend_descriptor(),
195  CUDNN_ATTR_OPERATIONGRAPH_HANDLE,
196  CUDNN_TYPE_HANDLE,
197  1,
198  &m_operationGraph.handle);
199  if (status != CUDNN_STATUS_SUCCESS) {
201  &m_operationGraph,
202  status,
203  "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: SetAttribute CUDNN_ATTR_OPERATIONGRAPH_HANDLE Failed");
204  return std::move(m_operationGraph);
205  }
206 
207  // Finalizing the descriptor
208  status = cudnnBackendFinalize(m_operationGraph.pointer->get_backend_descriptor());
209  if (status != CUDNN_STATUS_SUCCESS) {
211  &m_operationGraph, status, "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: cudnnFinalize Failed");
212  return std::move(m_operationGraph);
213  }
214 
215  return std::move(m_operationGraph);
216  }
217 
218  explicit OperationGraphBuilder_v8() = default;
219  ~OperationGraphBuilder_v8() = default;
223  operator=(OperationGraphBuilder_v8 const &) = delete;
224 
225  private:
227 };
228 }
static void set_error_and_throw_exception(BackendDescriptor const *desc, cudnnStatus_t status, const char *message)
auto setHandle(cudnnHandle_t handle_) -> OperationGraphBuilder_v8 &
Set cudnnHandle for the operations.
auto setOperationGraph(int64_t numOps_, Operation_v8 const **ops_) -> OperationGraphBuilder_v8 &
Set the number of operations and the operations.
ManagedOpaqueDescriptor get_desc() const
Returns a copy of underlying managed descriptor.
auto getEngineCount(void) const -> int64_t
Query the total count of the engines for the Operation Set.
cudnnStatus_t get_status() const
Current status of the descriptor.
std::string describe() const override
Return a string describing the backend Descriptor.
const char * get_error() const
Diagnostic error message, if any.
OperationGraph_v8 & operator=(OperationGraph_v8 const &)=delete
std::array< ManagedOpaqueDescriptor, 10 > ops
cudnnStatus_t status
Shared pointer of the OpaqueBackendPointer.