diff --git a/src/acc/opencl/acc_opencl.c b/src/acc/opencl/acc_opencl.c index bea1e2e7e03..0832a3af6e2 100644 --- a/src/acc/opencl/acc_opencl.c +++ b/src/acc/opencl/acc_opencl.c @@ -212,6 +212,7 @@ int c_dbcsr_acc_init(void) { const char* const env_dump = (NULL != env_dump_acc ? env_dump_acc : getenv("IGC_ShaderDumpEnable")); # if defined(ACC_OPENCL_NCCS) && (0 < ACC_OPENCL_NCCS) const char *const env_zex = getenv("ZEX_NUMBER_OF_CCS"), *const env_nccs = getenv("ACC_OPENCL_NCCS"); + const char* const env_flt = getenv("ZE_FLAT_DEVICE_HIERARCHY"); const int nccs = (NULL == env_nccs ? 0 : atoi(env_nccs)); # endif # if defined(ACC_OPENCL_IENV) @@ -245,7 +246,9 @@ int c_dbcsr_acc_init(void) { c_dbcsr_acc_opencl_config.timer = c_dbcsr_acc_opencl_timer_host; } # if defined(ACC_OPENCL_NCCS) && (0 < ACC_OPENCL_NCCS) - if ((NULL == env_zex && 0 == (4 & c_dbcsr_acc_opencl_config.xhints)) || 0 != nccs) { + if ((NULL == env_zex && NULL == env_flt && 0 == (4 & c_dbcsr_acc_opencl_config.xhints)) || + (0 == LIBXSMM_PUTENV("ZE_FLAT_DEVICE_HIERARCHY=COMPOSITE") && 0 != nccs)) + { static char zex_number_of_ccs[ACC_OPENCL_DEVICES_MAXCOUNT * 8 + 32] = "ZEX_NUMBER_OF_CCS="; int j = strlen(zex_number_of_ccs); for (i = 0; i < ACC_OPENCL_DEVICES_MAXCOUNT; ++i) { diff --git a/src/acc/opencl/smm/tune_multiply.py b/src/acc/opencl/smm/tune_multiply.py index d9d05559186..3efd4c72aa4 100755 --- a/src/acc/opencl/smm/tune_multiply.py +++ b/src/acc/opencl/smm/tune_multiply.py @@ -29,6 +29,7 @@ default_basename = "tune_multiply" default_mnk = "23x23x23" default_dbg = False +default_retry = 3 def env_intvalue(env, default, lookup=True): @@ -424,8 +425,9 @@ def merge_jsons(self, filenames): strkey = self.args.csvsep.join([str(k) for k in key]) strval = self.args.csvsep.join([str(v) for v in value[:-1]]) file.write("{}{}{}\n".format(strkey, self.args.csvsep, strval)) - retsld = retcnt = delsld = delcnt = 0 - retain, delete = [], [] + retsld, delsld = [0, 0, 0], [0, 0, 0] # [min, geo, max] + retain, delete = [], [] # lists of filenames + retcnt = delcnt = 0 # geo-counter for key, value in worse.items(): gflops = round(merged[key][1]) mtime = os.path.getmtime(merged[key][-1]) @@ -433,33 +435,50 @@ def merge_jsons(self, filenames): s = 0 if 0 < gflops: g = int(filename.split("-")[-1].split("g")[0]) - s = math.log(gflops / g) # slowdown + s = gflops / g # slowdown if mtime < os.path.getmtime(filename): - retsld, retcnt = retsld + s, retcnt + 1 + if 0 < s: + retsld[1] = retsld[1] + math.log(s) + retsld[0] = min(retsld[0], s) if 0 < retsld[0] else s + retsld[2] = max(retsld[2], s) + retcnt = retcnt + 1 retain.append(filename) else: - delsld, delcnt = delsld + s, delcnt + 1 + if 0 < s: + delsld[1] = delsld[1] + math.log(s) + delsld[0] = min(delsld[0], s) if 0 < delsld[0] else s + delsld[2] = max(delsld[2], s) + delcnt = delcnt + 1 delete.append(filename) if not self.args.nogflops: - slr = round(math.exp(retsld / retcnt), 1) if 0 < retcnt else 1 - sld = round(math.exp(delsld / delcnt), 1) if 0 < delcnt else 1 + retsld[1] = math.exp(retsld[1] / retcnt) if 0 < retcnt else 1 + delsld[1] = math.exp(delsld[1] / delcnt) if 0 < delcnt else 1 if not self.args.delete: if retain: num, lst = len(retain), " ".join(retain) - msg = "Worse and newer (retain {}@{}x): {}" - print(msg.format(num, slr, lst)) + msg = "Worse and newer (retain {} @ {}x): {}" + rnd = [str(round(i, 2)) for i in retsld] + print(msg.format(num, "..".join(rnd), lst)) if delete: num, lst = len(delete), " ".join(delete) - msg = "Worse and older (delete {}@{}x): {}" - print(msg.format(num, sld, lst)) + msg = "Worse and older (delete {} @ {}x): {}" + rnd = [str(round(i, 2)) for i in delsld] + print(msg.format(num, "..".join(rnd), lst)) else: for file in retain + delete: try: os.remove(file) except: # noqa: E722 pass - msg = " ({}..{}x)".format(slr, sld) if 1 < max(slr, sld) else "" - print("Removed outperformed parameters{}.".format(msg)) + msl = round(min(retsld[0], delsld[0]), 2) + xsl = round(max(retsld[2], delsld[2]), 2) + geo = round(math.sqrt(retsld[1] * delsld[1]), 2) + msg = "Removed outperformed parameter sets{}.".format( + " ({} @ {}..{}..{}x)".format(retcnt + delcnt, msl, geo, xsl) + if 0 < msl + else "" + ) + print(msg) elif bool(worse): print("WARNING: incorrectly merged duplicates") print(" due to nogflops argument!") @@ -554,7 +573,10 @@ def handle_sigint(self, signum, frame): self.handle_sigint_counter = self.handle_sigint_counter + 1 msg = "\nWARNING: tuning {}-kernel interrupted." print(msg.format("x".join(map(str, self.mnk)))) - self.save_final_config(self.config) + try: + self.save_final_config(self.config) + except: # noqa: E722 + pass exit(1) @@ -831,12 +853,16 @@ def handle_sigint(self, signum, frame): args.mb = 64 instance = SmmTuner(args) if not default_dbg: - try: - TuningRunMain(instance, args).main() - except Exception as e: - print("{}: {}".format(type(e).__name__, e)) - print("WARNING: ignored above error!") - instance.save_final_config(None, True) - pass + for retry in range(default_retry): + try: + TuningRunMain(instance, args).main() + exit(0) + except Exception as e: + msg = "IGNORED {} of {} {}: {}".format( + retry, default_retry, type(e).__name__, e + ) + print(msg) + pass + instance.save_final_config(None, True) else: TuningRunMain(instance, args).main()