CUDNN API  8
cudnn_frontend_OperationGraph.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice shall be included in
12  * all copies or substantial portions of the Software.
13  *
14  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20  * DEALINGS IN THE SOFTWARE.
21  */
22 
23 #pragma once
24 
25 #include <algorithm>
26 #include <array>
27 #include <functional>
28 #include <memory>
29 #include <sstream>
30 #include <utility>
31 
32 #include <cudnn.h>
33 #include <cudnn_backend.h>
34 
36 #include "cudnn_frontend_utils.h"
37 
38 namespace cudnn_frontend {
39 
52  public:
54  std::string
55  describe() const override {
56  std::stringstream ss;
57  ss << "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR :";
58  return ss.str();
59  }
60 
62  : BackendDescriptor(from.get_desc(), from.get_status(), from.get_error()),
63  handle(from.handle),
64  ops(from.ops),
65  numOps(from.numOps),
66  opGraphTag(from.opGraphTag) {}
67 
68  ~OperationGraph_v8() = default;
69 
74  auto
76  getEngineCount(void) const -> int64_t {
77  int64_t global_count = -1;
78  auto status = cudnnBackendGetAttribute(pointer->get_backend_descriptor(),
79  CUDNN_ATTR_OPERATIONGRAPH_ENGINE_GLOBAL_COUNT,
80  CUDNN_TYPE_INT64,
81  1,
82  NULL,
83  &global_count);
84  if (status != CUDNN_STATUS_SUCCESS) {
86  status,
87  "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: GetAttribute "
88  "CUDNN_ATTR_OPERATIONGRAPH_ENGINE_GLOBAL_COUNT Failed");
89  }
90  return global_count;
91  }
94  std::string const &
95  getTag() const {
96  return opGraphTag;
97  }
98 
99  private:
100  OperationGraph_v8() = default;
101  OperationGraph_v8(OperationGraph_v8 const &) = delete;
103  operator=(OperationGraph_v8 const &) = delete;
104 
105  cudnnHandle_t handle = nullptr;
106  std::array<ManagedOpaqueDescriptor, 10> ops{};
107  int64_t numOps = -1;
108  std::string opGraphTag = "";
109 };
110 
115  public:
120  auto
122  setHandle(cudnnHandle_t handle_) -> OperationGraphBuilder_v8 & {
123  m_operationGraph.handle = handle_;
124  return *this;
125  }
127  auto
128  setOperationGraph(int64_t numOps_, Operation_v8 const **ops_) -> OperationGraphBuilder_v8 & {
129  m_operationGraph.numOps = numOps_;
130  for (auto i = 0u; i < numOps_; i++) {
131  m_operationGraph.ops[i] = ops_[i]->get_desc();
132  m_operationGraph.opGraphTag += ops_[i]->getTag() + '_';
133  }
134  return *this;
135  }
141  build() {
142  if (m_operationGraph.numOps <= 0) {
144  &m_operationGraph,
145  CUDNN_STATUS_BAD_PARAM,
146  "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: Check and Set the CUDNN_ATTR_OPERATIONGRAPH_OPS Count field");
147  return std::move(m_operationGraph);
148  }
149  if (m_operationGraph.ops[0] == nullptr) {
151  &m_operationGraph,
152  CUDNN_STATUS_BAD_PARAM,
153  "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: Check and set CUDNN_ATTR_OPERATIONGRAPH_OPS field");
154  return std::move(m_operationGraph);
155  }
156  if (m_operationGraph.handle == nullptr) {
158  &m_operationGraph,
159  CUDNN_STATUS_BAD_PARAM,
160  "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: Check and Set CUDNN_ATTR_OPERATIONGRAPH_HANDLE");
161  return std::move(m_operationGraph);
162  }
163 
164  // Create a descriptor. Memory allocation happens here.
165  auto status = m_operationGraph.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR);
166  if (status != CUDNN_STATUS_SUCCESS) {
168  &m_operationGraph, status, "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: cudnnCreate Failed");
169  return std::move(m_operationGraph);
170  }
171 
172  std::array<cudnnBackendDescriptor_t, 10> ops_raw{nullptr};
173  for (auto i = 0u; i < m_operationGraph.numOps; i++) {
174  ops_raw[i] = m_operationGraph.ops[i]->get_backend_descriptor();
175  }
176 
177  status = cudnnBackendSetAttribute(m_operationGraph.pointer->get_backend_descriptor(),
178  CUDNN_ATTR_OPERATIONGRAPH_OPS,
179  CUDNN_TYPE_BACKEND_DESCRIPTOR,
180  m_operationGraph.numOps,
181  ops_raw.data());
182  if (status != CUDNN_STATUS_SUCCESS) {
184  &m_operationGraph,
185  status,
186  "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: SetAttribute CUDNN_ATTR_OPERATIONGRAPH_OPS Failed");
187  return std::move(m_operationGraph);
188  }
189  status = cudnnBackendSetAttribute(m_operationGraph.pointer->get_backend_descriptor(),
190  CUDNN_ATTR_OPERATIONGRAPH_HANDLE,
191  CUDNN_TYPE_HANDLE,
192  1,
193  &m_operationGraph.handle);
194  if (status != CUDNN_STATUS_SUCCESS) {
196  &m_operationGraph,
197  status,
198  "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: SetAttribute CUDNN_ATTR_OPERATIONGRAPH_HANDLE Failed");
199  return std::move(m_operationGraph);
200  }
201 
202  // Finalizing the descriptor
203  status = cudnnBackendFinalize(m_operationGraph.pointer->get_backend_descriptor());
204  if (status != CUDNN_STATUS_SUCCESS) {
206  &m_operationGraph, status, "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: cudnnFinalize Failed");
207  return std::move(m_operationGraph);
208  }
209 
210  return std::move(m_operationGraph);
211  }
212 
213  explicit OperationGraphBuilder_v8() = default;
214  ~OperationGraphBuilder_v8() = default;
218  operator=(OperationGraphBuilder_v8 const &) = delete;
219 
220  private:
222 };
223 }
static void set_error_and_throw_exception(BackendDescriptor const *desc, cudnnStatus_t status, const char *message)
auto setHandle(cudnnHandle_t handle_) -> OperationGraphBuilder_v8 &
Set cudnnHandle for the operations.
auto setOperationGraph(int64_t numOps_, Operation_v8 const **ops_) -> OperationGraphBuilder_v8 &
Set the number of operations and the operations themselves.
ManagedOpaqueDescriptor get_desc() const
Returns a copy of underlying managed descriptor.
auto getEngineCount(void) const -> int64_t
Query the total count of the engines for the Operation Set.
cudnnStatus_t get_status() const
Current status of the descriptor.
std::string describe() const override
Return a string describing the backend Descriptor.
const char * get_error() const
Diagnostic error message, if any.
OperationGraph_v8 & operator=(OperationGraph_v8 const &)=delete
std::array< ManagedOpaqueDescriptor, 10 > ops
cudnnStatus_t status
Shared pointer of the OpaqueBackendPointer.