CUDNN API  8
cudnn_frontend_find_plan.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice shall be included in
12  * all copies or substantial portions of the Software.
13  *
14  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20  * DEALINGS IN THE SOFTWARE.
21  */
22 
23 #pragma once
24 
26 #include <map>
27 
28 namespace cudnn_frontend {
29 
35 template <CudnnFindSamplingTechnique samplingTechnique>
36 auto
37 time_sorted_plan(cudnnHandle_t handle, executionPlans_t plans, VariantPack &variantPack) -> executionOptions_t {
38  executionOptions_t time_sorted_plans;
39  std::map<float, ExecutionPlan &> timed_execution_plans;
40 
41  const int maxIterCount =
43  ? 1
44  : (samplingTechnique == CudnnFindSamplingTechnique::CUDNN_FIND_SAMPLE_MEDIAN_OF_THREE) ? 3 : 100;
45  const float threshhold = 0.95f;
46 
47  cudaEvent_t start, stop;
48  cudaEventCreate(&start);
49  cudaEventCreate(&stop);
50  cudaDeviceSynchronize();
51 
52  for (auto &plan : plans) {
53  float time_ms = 0.0f;
54  float final_time_ms = 0.0f;
55  float min_time_ms = std::numeric_limits<float>::max();
56 
57  // Warm-up run
58  ::cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc());
59  cudaDeviceSynchronize();
60 
61  for (int i = 0; i < maxIterCount; i++) {
62  cudaEventRecord(start);
63 
64  ::cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc());
65 
66  cudaEventRecord(stop);
67  cudaEventSynchronize(stop);
68  cudaEventElapsedTime(&time_ms, start, stop);
69 
71  final_time_ms = std::min(min_time_ms, time_ms);
72  if (time_ms / min_time_ms < threshhold) {
73  min_time_ms = final_time_ms;
74  } else {
75  break;
76  }
77  } else {
78  final_time_ms = i == (maxIterCount / 2) ? time_ms : final_time_ms;
79  }
80  }
81  timed_execution_plans.insert({final_time_ms, plan});
82  }
83  std::transform(
84  timed_execution_plans.begin(),
85  timed_execution_plans.end(),
86  std::back_inserter(time_sorted_plans),
87  [](const std::map<float, cudnn_frontend::ExecutionPlan &>::value_type &pair) -> struct executionOption {
88  return {std::move(pair.second), pair.first};
89  });
90 
91  cudaEventDestroy(start);
92  cudaEventDestroy(stop);
93 
94  return time_sorted_plans;
95 }
96 
97 template <CudnnFindSamplingTechnique samplingTechnique>
98 auto
101  cudnn_frontend::VariantPack &variantPack,
102  Predicate pred) -> executionOptions_t {
104  executionPlans_t plans;
105  for (auto &engine_config : generate_engine_config(opGraph)) {
106 #ifndef NV_CUDNN_DISABLE_EXCEPTION
107  try {
108 #endif
109  plans.push_back(
110  cudnn_frontend::ExecutionPlanBuilder().setHandle(handle).setEngineConfig(engine_config).build());
111 #ifndef NV_CUDNN_DISABLE_EXCEPTION
112  } catch (cudnnException e) {
113  continue;
114  }
115 #endif
116  }
117  return time_sorted_plan<samplingTechnique>(handle, filter(pred, plans), variantPack);
118 }
119 }
std::function< bool(cudnn_frontend::ExecutionPlan const &plan)> Predicate
std::vector< cudnn_frontend::ExecutionPlan > executionPlans_t
std::vector< struct executionOption > executionOptions_t
Variety of renames.
auto filter(Predicate pred, executionPlans_t &plans) -> executionPlans_t
auto time_sorted_plan(cudnnHandle_t handle, executionPlans_t plans, VariantPack &variantPack) -> executionOptions_t
auto cudnnFindPlan(cudnnHandle_t handle, cudnn_frontend::OperationGraph &&opGraph, cudnn_frontend::VariantPack &variantPack, Predicate pred) -> executionOptions_t
Sample once: quick, but the measured time may be unstable.