Skip to content

Commit

Permalink
ocl: improved device setup
Browse files Browse the repository at this point in the history
* Control ZE_FLAT_DEVICE_HIERARCHY.
* Improved info output.
  • Loading branch information
hfp committed Nov 8, 2023
1 parent 4eec57f commit b11d8cb
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 22 deletions.
5 changes: 4 additions & 1 deletion src/acc/opencl/acc_opencl.c
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ int c_dbcsr_acc_init(void) {
const char* const env_dump = (NULL != env_dump_acc ? env_dump_acc : getenv("IGC_ShaderDumpEnable"));
# if defined(ACC_OPENCL_NCCS) && (0 < ACC_OPENCL_NCCS)
const char *const env_zex = getenv("ZEX_NUMBER_OF_CCS"), *const env_nccs = getenv("ACC_OPENCL_NCCS");
const char* const env_flt = getenv("ZE_FLAT_DEVICE_HIERARCHY");
const int nccs = (NULL == env_nccs ? 0 : atoi(env_nccs));
# endif
# if defined(ACC_OPENCL_IENV)
Expand Down Expand Up @@ -245,7 +246,9 @@ int c_dbcsr_acc_init(void) {
c_dbcsr_acc_opencl_config.timer = c_dbcsr_acc_opencl_timer_host;
}
# if defined(ACC_OPENCL_NCCS) && (0 < ACC_OPENCL_NCCS)
if ((NULL == env_zex && 0 == (4 & c_dbcsr_acc_opencl_config.xhints)) || 0 != nccs) {
if ((NULL == env_zex && NULL == env_flt && 0 == (4 & c_dbcsr_acc_opencl_config.xhints)) ||
(0 == LIBXSMM_PUTENV("ZE_FLAT_DEVICE_HIERARCHY=COMPOSITE") && 0 != nccs))
{
static char zex_number_of_ccs[ACC_OPENCL_DEVICES_MAXCOUNT * 8 + 32] = "ZEX_NUMBER_OF_CCS=";
int j = strlen(zex_number_of_ccs);
for (i = 0; i < ACC_OPENCL_DEVICES_MAXCOUNT; ++i) {
Expand Down
68 changes: 47 additions & 21 deletions src/acc/opencl/smm/tune_multiply.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
default_basename = "tune_multiply"
default_mnk = "23x23x23"
default_dbg = False
default_retry = 3


def env_intvalue(env, default, lookup=True):
Expand Down Expand Up @@ -424,42 +425,60 @@ def merge_jsons(self, filenames):
strkey = self.args.csvsep.join([str(k) for k in key])
strval = self.args.csvsep.join([str(v) for v in value[:-1]])
file.write("{}{}{}\n".format(strkey, self.args.csvsep, strval))
retsld = retcnt = delsld = delcnt = 0
retain, delete = [], []
retsld, delsld = [0, 0, 0], [0, 0, 0] # [min, geo, max]
retain, delete = [], [] # lists of filenames
retcnt = delcnt = 0 # geo-counter
for key, value in worse.items():
gflops = round(merged[key][1])
mtime = os.path.getmtime(merged[key][-1])
for filename in value:
s = 0
if 0 < gflops:
g = int(filename.split("-")[-1].split("g")[0])
s = math.log(gflops / g) # slowdown
s = gflops / g # slowdown
if mtime < os.path.getmtime(filename):
retsld, retcnt = retsld + s, retcnt + 1
if 0 < s:
retsld[1] = retsld[1] + math.log(s)
retsld[0] = min(retsld[0], s) if 0 < retsld[0] else s
retsld[2] = max(retsld[2], s)
retcnt = retcnt + 1
retain.append(filename)
else:
delsld, delcnt = delsld + s, delcnt + 1
if 0 < s:
delsld[1] = delsld[1] + math.log(s)
delsld[0] = min(delsld[0], s) if 0 < delsld[0] else s
delsld[2] = max(delsld[2], s)
delcnt = delcnt + 1
delete.append(filename)
if not self.args.nogflops:
slr = round(math.exp(retsld / retcnt), 1) if 0 < retcnt else 1
sld = round(math.exp(delsld / delcnt), 1) if 0 < delcnt else 1
retsld[1] = math.exp(retsld[1] / retcnt) if 0 < retcnt else 1
delsld[1] = math.exp(delsld[1] / delcnt) if 0 < delcnt else 1
if not self.args.delete:
if retain:
num, lst = len(retain), " ".join(retain)
msg = "Worse and newer (retain {}@{}x): {}"
print(msg.format(num, slr, lst))
msg = "Worse and newer (retain {} @ {}x): {}"
rnd = [str(round(i, 2)) for i in retsld]
print(msg.format(num, "..".join(rnd), lst))
if delete:
num, lst = len(delete), " ".join(delete)
msg = "Worse and older (delete {}@{}x): {}"
print(msg.format(num, sld, lst))
msg = "Worse and older (delete {} @ {}x): {}"
rnd = [str(round(i, 2)) for i in delsld]
print(msg.format(num, "..".join(rnd), lst))
else:
for file in retain + delete:
try:
os.remove(file)
except: # noqa: E722
pass
msg = " ({}..{}x)".format(slr, sld) if 1 < max(slr, sld) else ""
print("Removed outperformed parameters{}.".format(msg))
msl = round(min(retsld[0], delsld[0]), 2)
xsl = round(max(retsld[2], delsld[2]), 2)
geo = round(math.sqrt(retsld[1] * delsld[1]), 2)
msg = "Removed outperformed parameter sets{}.".format(
" ({} @ {}..{}..{}x)".format(retcnt + delcnt, msl, geo, xsl)
if 0 < msl
else ""
)
print(msg)
elif bool(worse):
print("WARNING: incorrectly merged duplicates")
print(" due to nogflops argument!")
Expand Down Expand Up @@ -554,7 +573,10 @@ def handle_sigint(self, signum, frame):
self.handle_sigint_counter = self.handle_sigint_counter + 1
msg = "\nWARNING: tuning {}-kernel interrupted."
print(msg.format("x".join(map(str, self.mnk))))
self.save_final_config(self.config)
try:
self.save_final_config(self.config)
except: # noqa: E722
pass
exit(1)


Expand Down Expand Up @@ -831,12 +853,16 @@ def handle_sigint(self, signum, frame):
args.mb = 64
instance = SmmTuner(args)
if not default_dbg:
try:
TuningRunMain(instance, args).main()
except Exception as e:
print("{}: {}".format(type(e).__name__, e))
print("WARNING: ignored above error!")
instance.save_final_config(None, True)
pass
for retry in range(default_retry):
try:
TuningRunMain(instance, args).main()
exit(0)
except Exception as e:
msg = "IGNORED {} of {} {}: {}".format(
retry, default_retry, type(e).__name__, e
)
print(msg)
pass
instance.save_final_config(None, True)
else:
TuningRunMain(instance, args).main()

0 comments on commit b11d8cb

Please sign in to comment.