CUDNN Frontend API  8.2.0
cudnn_frontend_Errata.h
Go to the documentation of this file.
1 
2 /*
3  * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4  *
5  * Permission is hereby granted, free of charge, to any person obtaining a
6  * copy of this software and associated documentation files (the "Software"),
7  * to deal in the Software without restriction, including without limitation
8  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
9  * and/or sell copies of the Software, and to permit persons to whom the
10  * Software is furnished to do so, subject to the following conditions:
11  *
12  * The above copyright notice and this permission notice shall be included in
13  * all copies or substantial portions of the Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
21  * DEALINGS IN THE SOFTWARE.
22  */
23 
25 
26 #include <cstdlib>
27 #include <fstream>
28 #pragma once
29 
31 
32 namespace cudnn_frontend {
33 
34 // Loads the json handle from the json file
35 // json file is defined by environment variable
36 // CUDNN_ERRATA_JSON_FILE. If the environment variable
37 // is not set the value set in the API is considered.
38 static void
39 load_from_config(json &json_handle, const std::string & errata_json) {
40  const char * err_json = std::getenv("CUDNN_ERRATA_JSON_FILE");
41  if (err_json == NULL && errata_json == "") {return;}
42  if (err_json == NULL) { err_json = errata_json.c_str();}
43  std::ifstream ifs(err_json, std::ifstream::in);
44  if (!ifs.is_open() || !ifs.good()) {return;}
45  ifs >> json_handle;
46  return;
47 }
48 
49 template <typename T>
50 static bool
51 check_rule(const json &json_handle, const std::string & executionPlanTag,
52  cudnnHandle_t handle, T fn) {
53  std::string operation = json_handle["operation"];
54  std::string engine = json_handle["engine"];
55  uint64_t cudnn_start = 0;
56  uint64_t cudnn_end = -1;
57  if (json_handle.contains("cudnn_version_start")) {
58  cudnn_start = json_handle["cudnn_version_start"];
59  }
60  if (json_handle.contains("cudnn_version_end")) {
61  cudnn_end = json_handle["cudnn_version_end"];
62  }
63  std::string tag_prefix = operation + "_" + engine;
64  bool blocked =
65  std::equal(tag_prefix.begin(), tag_prefix.end(), executionPlanTag.begin()) &&
66  CUDNN_VERSION >= cudnn_start &&
67  CUDNN_VERSION < cudnn_end;
68 
69  if (blocked && json_handle.contains("knob")) { // Short circuit if operation and engine do not match
70  for (auto& kv : json_handle["knob"]) {
71  blocked = blocked &&
72  (executionPlanTag.find(kv) != std::string::npos);
73  }
74  }
75 
76  blocked = blocked && fn();
77  return blocked;
78 
79  (void) handle;
80 }
81 
82 // Takes in an initialzed json handle and checks if it satisfies the
83 // condition for running it. Returns true if the given executionPlanTag
84 // is faulty.
85 template <typename T>
86 static bool
87 check_errata(const json &json_handle, const std::string & executionPlanTag,
88  cudnnHandle_t handle, T fn) {
89 
90  for (auto const &rule : json_handle["rules"]) {
91  if (check_rule<T>(rule, executionPlanTag, handle, fn)) {
92  return true;
93  }
94  }
95 
96  return false;
97 }
98 
99 }
bool contains(KeyT &&key) const
check the existence of an element in a JSON object
Definition: json.hpp:21558
a class to store JSON values
Definition: json.hpp:3366
static bool check_errata(const json &json_handle, const std::string &executionPlanTag, cudnnHandle_t handle, T fn)
static void load_from_config(json &json_handle, const std::string &errata_json)
static bool check_rule(const json &json_handle, const std::string &executionPlanTag, cudnnHandle_t handle, T fn)
j template void())
Definition: json.hpp:4061
basic_json<> json
default JSON class
Definition: json.hpp:3390