From 6dcd9689b161f89f8b1db99bfca4718087b9ba72 Mon Sep 17 00:00:00 2001 From: Saad Khan Date: Thu, 12 May 2022 13:59:28 +0530 Subject: [PATCH 1/6] add basic db_connection code Signed-off-by: Saad Khan --- src/bayes_optuna/optuna_hpo.py | 6 ++- src/db/hpo.db | Bin 0 -> 28672 bytes src/db_connection.py | 76 +++++++++++++++++++++++++++++++++ src/rest_service.py | 7 +-- 4 files changed, 85 insertions(+), 4 deletions(-) create mode 100644 src/db/hpo.db create mode 100644 src/db_connection.py diff --git a/src/bayes_optuna/optuna_hpo.py b/src/bayes_optuna/optuna_hpo.py index a4a7d7a..a1c9751 100644 --- a/src/bayes_optuna/optuna_hpo.py +++ b/src/bayes_optuna/optuna_hpo.py @@ -134,8 +134,12 @@ def recommend(self): elif self.hpo_algo_impl == "optuna_skopt": sampler = optuna.integration.SkoptSampler() + study_name = self.experiment_name + storage_name = "sqlite:///./src/db/hpo.db" + # Create a study object - study = optuna.create_study(direction=self.direction, sampler=sampler, study_name=self.experiment_name) + study = optuna.create_study(direction=self.direction, sampler=sampler, study_name=study_name, + storage=storage_name) # Execute an optimization by using an 'Objective' instance study.optimize(Objective(self), n_trials=self.total_trials, n_jobs=self.parallel_trials) diff --git a/src/db/hpo.db b/src/db/hpo.db new file mode 100644 index 0000000000000000000000000000000000000000..96d1fd0b37fa254c867411862afbcb871ebb4d43 GIT binary patch literal 28672 zcmeI%(QDH{9Ki8p?OGY_Ht-=&g>)}#VJLC+;6wugbxFDBEm7wG%SwxX)aOI=qKr4U`^ z=!%PM7vo0sRF~>m)q?V5uUp;vsgx=^%I3GN@73MS%gRpWM^*s+2q1s}0tg_000Iag z@IM7UAFr3ST1}mYfp<0%&Pm67Wn2CDS!(qx)3J0%_Udu(hQ3-F!*jjearADleQ5Sh z^>^0kVo2x6(K{!$y_XINLSMeo^}Vr(#=bXut+!^c@!WXUl%e;Cyzxe16#CjeGM%M} z=F?J1Yt>X`=%0(*+4U&e0=ZroOcrf7PQo?fT6zAWP}U4X{kpqAus(LJC=aXSxCfSF zw(Wkprbm~NYZJCoweOgR z-RAyUNqbYj2kjE6+{-_b|0-2^+m<&90tg_000IagfB*srAb009ILKmY** L5I_I{1Q7TIkCO!^ literal 0 HcmV?d00001 diff --git a/src/db_connection.py b/src/db_connection.py new file mode 100644 index 0000000..e2b96c8 --- /dev/null +++ b/src/db_connection.py @@ -0,0 +1,76 @@ +""" +Copyright (c) 2020, 2022 Red Hat, IBM Corporation and others. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import sqlite3 + +from datetime import date + + +def conn_create(): + conn = sqlite3.connect('./db/hpo.db') + + print("Opened database successfully") + + try: + conn.execute('''CREATE TABLE EXPERIMENT + (name VARCHAR(512) PRIMARY KEY NOT NULL, + search_space VARCHAR(512) NOT NULL, + objective_function VARCHAR(512) NOT NULL, + created_at TIMESTAMP);''') + except: + sqlite3.OperationalError + + print("EXPERIMENT Table created successfully") + try: + conn.execute('''CREATE TABLE EXPERIMENT_DETAILS + (id INT PRIMARY KEY NOT NULL, + experiment_name VARCHAR(512) NOT NULL, + tunable_id INT NOT NULL, + results FLOAT NOT NULL, + created_at DATETIME);''') + except: + sqlite3.OperationalError + + print("EXPERIMENT_DETAILS Table created successfully") + + try: + conn.execute('''CREATE TABLE TUNABLES + (id INT PRIMARY KEY NOT NULL, + tunable_name VARCHAR(512) NOT NULL, + tunable_value FLOAT NOT NULL);''') + except: + sqlite3.OperationalError + + print("TUNABLES Table created successfully") + + +class DBConnectionHandler: + + def insert_data(self, exp_name, search_space, obj_function): + + self.execute("INSERT INTO EXPERIMENT (NAME,SEARCH_SPACE,OBJECTIVE_FUNCTION,CREATED_AT) " + "VALUES (?, ?, ?)", (exp_name, search_space, obj_function, date.today())) + self.commit() + + print("Records created successfully") + + def get_data(self): + cursor = self.execute("SELECT name, search_space, objective_function from EXPERIMENT") + for row in cursor: + print("NAME = ", row[0]) + print("SEARCH_SPACE = ", row[1]) + print("OBJECTIVE_FUNCTION = ", row[2], "\n") + + print("Operation done successfully") diff --git a/src/rest_service.py b/src/rest_service.py index 5cb6ced..2ebeb3f 100644 --- a/src/rest_service.py +++ b/src/rest_service.py @@ -27,6 +27,7 @@ from logger import get_logger import hpo_service +import db_connection logger = get_logger(__name__) @@ -113,13 +114,13 @@ def getHomeScreen(self): def handle_generate_new_operation(self, json_object): """Process EXP_TRIAL_GENERATE_NEW operation.""" is_valid_json_object = validate_trial_generate_json(json_object) - - if is_valid_json_object and hpo_service.instance.doesNotContainExperiment( - json_object["search_space"]["experiment_name"]): + experiment_name = json_object["search_space"]["experiment_name"] + if is_valid_json_object and hpo_service.instance.doesNotContainExperiment(experiment_name): search_space_json = json_object["search_space"] if str(search_space_json["experiment_name"]).isspace() or not str(search_space_json["experiment_name"]): self._set_response(400, "-1") return + db_connection.DBConnectionHandler.insert_data(experiment_name, search_space_json, search_space_json["objective_function"]) get_search_create_study(search_space_json, json_object["operation"]) trial_number = hpo_service.instance.get_trial_number(json_object["search_space"]["experiment_name"]) self._set_response(200, str(trial_number)) From 2cbf5fdf44b1c56c48939cdc8d7c12192f2ab135 Mon Sep 17 00:00:00 2001 From: Saad Khan Date: Fri, 13 May 2022 13:58:47 +0530 Subject: [PATCH 2/6] update db_connection.py to insert data in experiment and experiment_details table Signed-off-by: Saad Khan --- src/db/hpo.db => hpo.db | Bin 28672 -> 28672 bytes src/bayes_optuna/optuna_hpo.py | 4 +- src/db_connection.py | 74 ++++++++++++++++++++------------- src/rest_service.py | 4 +- 4 files changed, 49 insertions(+), 33 deletions(-) rename src/db/hpo.db => hpo.db (94%) diff --git a/src/db/hpo.db b/hpo.db similarity index 94% rename from src/db/hpo.db rename to hpo.db index 96d1fd0b37fa254c867411862afbcb871ebb4d43..1f134cd649c5b49b5f2dc1b4b963f35e834f7317 100644 GIT binary patch literal 28672 zcmeI)Z%@-e90%~$VK)(+&-RF5@vfn;KdG6|kbEy~+FtK>z5Db+a#x_U(Xu(yUwc81a=oTpR#a8F zqw9*IXkyQbJxLjHGLc+}quO_#cB&~0rS<9jPeq%$s$BY#|2BQ?l0S8I>UUZI@*w~L z2tWV=5P$##AOL~WAnAU)z_Z1TBXsDqsoX&drXvl+GsC7Y_#XjuE*{ zok)bGNY1hcw=&s+VW@}27=i8h zEUKdSI6z(xwku#^5P$##AOHafKmY;|fB*ze zt-#@JRqK2XZ_oOSn~v?;W;vui-(lri`Fee6Y0<3L7jKkZe|yW?E1t}bNu<~*ATh!3Jl@9K9k{=7m2kaS-~6M zBh=aQh%K5eZ+X64pCDWm6t_p`%d!j)s2kFF`b@;agzvc_BV4S*#7#O{m=XnF?xjB0 z@!qq5YDzF6itE?ld~!XanAjNC8jV?U+u z8KxhlKj5VL%0k81e9UVO0`<9)XPR~ zQjBdi|3eWQ5(FRs0SG_<0uX=z1Rwwb2tWV=|C_*-Oy|NFt2^=c|NM7FY)BA*00bZa z0SG_<0uX=z1Rwwb2%LU_nyTeGx!gFv3o`uu|Cf^gb^1Gq7D5055P$##AOHafKmY;| zfB*y_@b3bdOzwjG{xAON9})y0009U<00Izz00bZa0SG_<0%ur2{{27Z|7Z9GqiYa= W00bZa0SG_<0uX=z1RwwbA@B#cwEK<# delta 379 zcmZp8z}WDBae_1>^F$eENoEE;od90`9}H|fN(_8!`497|^JVZT@qF7XD6ooqvK9}& zT1cp$qmz$oFpOq%4svx2aa9Nbi7LQEC&%;bt%q^q4Ph+yCOc+!aYIAK7WI Date: Mon, 16 May 2022 13:52:16 +0530 Subject: [PATCH 3/6] Add new table for configs, modified experiment_details Signed-off-by: Saad Khan --- hpo.db | Bin 28672 -> 0 bytes src/bayes_optuna/optuna_hpo.py | 7 ++++ src/db_connection.py | 65 ++++++++++++++++++++++----------- src/rest_service.py | 5 ++- 4 files changed, 54 insertions(+), 23 deletions(-) delete mode 100644 hpo.db diff --git a/hpo.db b/hpo.db deleted file mode 100644 index 1f134cd649c5b49b5f2dc1b4b963f35e834f7317..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 28672 zcmeI)Z%@-e90%~$VK)(+&-RF5@vfn;KdG6|kbEy~+FtK>z5Db+a#x_U(Xu(yUwc81a=oTpR#a8F zqw9*IXkyQbJxLjHGLc+}quO_#cB&~0rS<9jPeq%$s$BY#|2BQ?l0S8I>UUZI@*w~L z2tWV=5P$##AOL~WAnAU)z_Z1TBXsDqsoX&drXvl+GsC7Y_#XjuE*{ zok)bGNY1hcw=&s+VW@}27=i8h zEUKdSI6z(xwku#^5P$##AOHafKmY;|fB*ze zt-#@JRqK2XZ_oOSn~v?;W;vui-(lri`Fee6Y0<3L7jKkZe|yW?E1t}bNu<~*ATh!3Jl@9K9k{=7m2kaS-~6M zBh=aQh%K5eZ+X64pCDWm6t_p`%d!j)s2kFF`b@;agzvc_BV4S*#7#O{m=XnF?xjB0 z@!qq5YDzF6itE?ld~!XanAjNC8jV?U+u z8KxhlKj5VL%0k81e9UVO0`<9)XPR~ zQjBdi|3eWQ5(FRs0SG_<0uX=z1Rwwb2tWV=|C_*-Oy|NFt2^=c|NM7FY)BA*00bZa z0SG_<0uX=z1Rwwb2%LU_nyTeGx!gFv3o`uu|Cf^gb^1Gq7D5055P$##AOHafKmY;| zfB*y_@b3bdOzwjG{xAON9})y0009U<00Izz00bZa0SG_<0%ur2{{27Z|7Z9GqiYa= W00bZa0SG_<0uX=z1RwwbA@B#cwEK<# diff --git a/src/bayes_optuna/optuna_hpo.py b/src/bayes_optuna/optuna_hpo.py index 3e19054..452e049 100644 --- a/src/bayes_optuna/optuna_hpo.py +++ b/src/bayes_optuna/optuna_hpo.py @@ -19,6 +19,8 @@ from logger import get_logger +from .. import db_connection + logger = get_logger(__name__) trials = [] @@ -181,6 +183,11 @@ def recommend(self): logger.info("Recommended config: " + str(recommended_config)) + # call db function to store the configs + + db_connection.insert_config_details(self.experiment_name, study.best_trial, study.best_params, study.best_value, + recommended_config) + class Objective(TrialDetails): """ diff --git a/src/db_connection.py b/src/db_connection.py index 788e5bd..be4e49f 100644 --- a/src/db_connection.py +++ b/src/db_connection.py @@ -19,11 +19,9 @@ from datetime import date db_path = os.path.abspath("hpo.db") -conn = None def conn_create(): - global conn conn = sqlite3.connect(db_path) print("Opened database successfully") @@ -35,46 +33,69 @@ def conn_create(): search_space TEXT NOT NULL, objective_function VARCHAR(512) NOT NULL, created_at TIMESTAMP);''') - print("EXPERIMENT Table created successfully") + print("Experiment Table created successfully") except: sqlite3.OperationalError try: cursor.execute('''CREATE TABLE experiment_details - (id INT PRIMARY KEY NOT NULL, - experiment_name VARCHAR(512) NOT NULL, - tunable_id INT NOT NULL, - results FLOAT NOT NULL, - created_at DATETIME);''') - print("EXPERIMENT_DETAILS Table created successfully") + (id INTEGER PRIMARY KEY, + trial_number INTEGER, + experiment_name VARCHAR NOT NULL, + trial_json VARCHAR, + results_value FLOAT, + trial_result VARCHAR, + created_at DATETIME, + FOREIGN KEY (experiment_name) + REFERENCES experiment (name));''') + print("Experiment_Details Table created successfully") except: sqlite3.OperationalError try: - cursor.execute('''CREATE TABLE tunables - (id INT PRIMARY KEY NOT NULL, - tunable_name VARCHAR(512) NOT NULL, - tunable_value FLOAT NOT NULL);''') + cursor.execute('''CREATE TABLE configs + (id INTEGER PRIMARY KEY, + experiment_name VARCHAR, + best_config VARCHAR, + best_parameter VARCHAR, + best_value FLOAT, + recommended_config VARCHAR, + FOREIGN KEY (experiment_name) + REFERENCES experiment (name));''') print("Tunables Table created successfully") except: sqlite3.OperationalError -def insert_data(experiment_name, search_space_json, obj_function): - global conn +def insert_experiment_data(experiment_name, search_space_json, obj_function): conn = sqlite3.connect(db_path) cursor = conn.cursor() cursor.execute("INSERT INTO experiment (name,search_space,objective_function,created_at) " "VALUES (?, ?, ?, ?)", (experiment_name, str(search_space_json), obj_function, date.today())) - result = cursor.execute("select id from experiment_details") - if result.rowcount == 0: - id = 0 - else: - id = result.rowcount + 1 - cursor.execute("INSERT INTO experiment_details (id,experiment_name,tunable_id,results,created_at) " - "VALUES (?, ?, ?,?,?)", (id, experiment_name, 0, 0.0, date.today())) conn.commit() + print("Record created successfully") + +def insert_experiment_details(json_object, trial_json): + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + cursor.execute("INSERT INTO experiment_details (trial_number,experiment_name,trial_json,results_value," + "trial_result, created_at) " + "VALUES (?, ?, ?, ?, ?, ?)", (json_object["trial_number"], json_object["experiment_name"], + trial_json, json_object["result_value"], + json_object["trial_result"], date.today())) + conn.commit() + print("Record created successfully") + + +def insert_config_details(experiment_name, best_config, best_parameter, best_value, recommended_config): + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + cursor.execute("INSERT INTO configs (experiment_name, best_config, best_parameter, best_value, " + "recommended_config, created_at) " + "VALUES (?, ?, ?, ?, ?, ?)", (experiment_name, best_config, best_parameter, best_value, + recommended_config, date.today())) + conn.commit() print("Record created successfully") diff --git a/src/rest_service.py b/src/rest_service.py index 61ffbe6..470f57e 100644 --- a/src/rest_service.py +++ b/src/rest_service.py @@ -122,7 +122,7 @@ def handle_generate_new_operation(self, json_object): return obj_function = search_space_json["objective_function"] db_connection.conn_create() - db_connection.insert_data(experiment_name, search_space_json, obj_function) + db_connection.insert_experiment_data(experiment_name, search_space_json, obj_function) get_search_create_study(search_space_json, json_object["operation"]) trial_number = hpo_service.instance.get_trial_number(json_object["search_space"]["experiment_name"]) self._set_response(200, str(trial_number)) @@ -149,6 +149,9 @@ def handle_result_operation(self, json_object): hpo_service.instance.set_result(json_object["experiment_name"], json_object["trial_result"], json_object["result_value_type"], json_object["result_value"]) + # call db function to store experiment details after each trial + trial_json = hpo_service.instance.get_trial_json_object(json_object["experiment_name"]) + db_connection.insert_experiment_details(json_object, trial_json) self._set_response(200, "0") else: self._set_response(400, "-1") From 12ab6cf75e0b962eff9e955d760abae64c5c6f11 Mon Sep 17 00:00:00 2001 From: Saad Khan Date: Wed, 18 May 2022 21:21:21 +0530 Subject: [PATCH 4/6] Add recommendations api to fetch configs from DB and other bug fixes Signed-off-by: Saad Khan --- hpo.db | Bin 0 -> 20480 bytes src/bayes_optuna/optuna_hpo.py | 6 +- src/db_connection.py | 113 -------------------------- src/db_files/__init__.py | 0 src/db_files/db_connection.py | 143 +++++++++++++++++++++++++++++++++ src/rest_service.py | 47 +++++++++-- 6 files changed, 188 insertions(+), 121 deletions(-) create mode 100644 hpo.db delete mode 100644 src/db_connection.py create mode 100644 src/db_files/__init__.py create mode 100644 src/db_files/db_connection.py diff --git a/hpo.db b/hpo.db new file mode 100644 index 0000000000000000000000000000000000000000..f92adfeeab03ee9c190e4456e51b590e4b16cabd GIT binary patch literal 20480 zcmeHPOK;mo5T<0uZ%4iPkPG9Ykp;*EmlP>VsL&{`{6PGwq%;l!gP^6g&6Xl%k;=n{ z4{Fl*+@H`(PC*9afadU{xBQ*n;D&6XRiMo z`eo)XDA0!@KoOt_Py{Ff6ak6=MSvpk5g_m^JUTvi<3`}cXO6aK;DTlD>j!q@LqTQ7ENf}r45;h&x)JmBB zI-OlwO=o*h?qSlpM$Ko)5py0Pt(U^Z^qt`Hu_KmC;}7#iU396<3wQe#m6ZZm5L%j5%@3= zc#eW&xwH5e=`wZ-hHmPGMO!PC4ZJ8U#wA6G6(lK^STxIz4vHt?_D^>=+7Bn$@Brh& zQPw+BFOqQ@5=7GBSdM0>Zh&V)y9iBWlfW>rrxih~8y>38(=sN*i&hmjHwc5nvZZRq zfu-uOY|>?wovNv+4iG|jcdSEboQh`Jnj8A45LT8g)5fX;`xx$$U^p-*%XxU>JFj6q z#uatXs+vR|B=ImA35f~&5-wSlr&)Ydg-s8_vRa01k`}SZ6Jp!JWis+K$__k>XO2R- z>OVo7J!%==oLv zY{u?kGvbgo(>;!W=CQVx&BR;bn+B#B7! zvcDvVLy5Q;kJVO3`j}UDKsS`r^Rc0nf;gFoa^IE2$LXb{AaOB;=i^dsCFpNnmn2Cd zF_a_%*v&;g50W5GfjkTW93|jyUY8>Y8rtQ>3(ZSPBB^oy@8@YKn@P|GW0g3=c zfFeKLE)rym9R=H+jfdr{hR*JGnZoEpA6(Jis z#tR8qh2R+m;8rKQkymM7A5|xpA#W>|EGDnKs($>cI zdM2M~Xfw5GnK*J1Zsj7o`Rkq}?f^s zbKe|pGEI{UvBYNed>TDK7H(|PG_u~CvUK`GbG*kfT=qTqwO8kH9q)MXp=Pyf`0nYc zch^QqBEFq0e6}5^JM}>?r|yn>de}ONG+}bj${-J)gptiCvWHWHd$ycf9MY*lVhNNW zC!`n>hxKVNRWTlA<4RoN<9%J)SIqx~OM}AR+NIs=1%tSBx0$^Umj*lZ^&#kZTbFLw P%djp@Jji!BeO>w=_qc1! literal 0 HcmV?d00001 diff --git a/src/bayes_optuna/optuna_hpo.py b/src/bayes_optuna/optuna_hpo.py index 452e049..8284e4a 100644 --- a/src/bayes_optuna/optuna_hpo.py +++ b/src/bayes_optuna/optuna_hpo.py @@ -19,7 +19,7 @@ from logger import get_logger -from .. import db_connection +from db_files import db_connection logger = get_logger(__name__) @@ -185,8 +185,8 @@ def recommend(self): # call db function to store the configs - db_connection.insert_config_details(self.experiment_name, study.best_trial, study.best_params, study.best_value, - recommended_config) + db_connection.insert_config_details(self.experiment_name, str(study.best_params), study.best_value, + str(study.best_trial), str(recommended_config)) class Objective(TrialDetails): diff --git a/src/db_connection.py b/src/db_connection.py deleted file mode 100644 index be4e49f..0000000 --- a/src/db_connection.py +++ /dev/null @@ -1,113 +0,0 @@ -""" -Copyright (c) 2020, 2022 Red Hat, IBM Corporation and others. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" -import os -import sqlite3 - -from datetime import date - -db_path = os.path.abspath("hpo.db") - - -def conn_create(): - conn = sqlite3.connect(db_path) - print("Opened database successfully") - - cursor = conn.cursor() - - try: - cursor.execute('''CREATE TABLE experiment - (name VARCHAR(512) PRIMARY KEY NOT NULL, - search_space TEXT NOT NULL, - objective_function VARCHAR(512) NOT NULL, - created_at TIMESTAMP);''') - print("Experiment Table created successfully") - except: - sqlite3.OperationalError - - try: - cursor.execute('''CREATE TABLE experiment_details - (id INTEGER PRIMARY KEY, - trial_number INTEGER, - experiment_name VARCHAR NOT NULL, - trial_json VARCHAR, - results_value FLOAT, - trial_result VARCHAR, - created_at DATETIME, - FOREIGN KEY (experiment_name) - REFERENCES experiment (name));''') - print("Experiment_Details Table created successfully") - except: - sqlite3.OperationalError - - try: - cursor.execute('''CREATE TABLE configs - (id INTEGER PRIMARY KEY, - experiment_name VARCHAR, - best_config VARCHAR, - best_parameter VARCHAR, - best_value FLOAT, - recommended_config VARCHAR, - FOREIGN KEY (experiment_name) - REFERENCES experiment (name));''') - print("Tunables Table created successfully") - except: - sqlite3.OperationalError - - -def insert_experiment_data(experiment_name, search_space_json, obj_function): - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - cursor.execute("INSERT INTO experiment (name,search_space,objective_function,created_at) " - "VALUES (?, ?, ?, ?)", (experiment_name, str(search_space_json), obj_function, date.today())) - conn.commit() - print("Record created successfully") - - -def insert_experiment_details(json_object, trial_json): - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - cursor.execute("INSERT INTO experiment_details (trial_number,experiment_name,trial_json,results_value," - "trial_result, created_at) " - "VALUES (?, ?, ?, ?, ?, ?)", (json_object["trial_number"], json_object["experiment_name"], - trial_json, json_object["result_value"], - json_object["trial_result"], date.today())) - conn.commit() - print("Record created successfully") - - -def insert_config_details(experiment_name, best_config, best_parameter, best_value, recommended_config): - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - cursor.execute("INSERT INTO configs (experiment_name, best_config, best_parameter, best_value, " - "recommended_config, created_at) " - "VALUES (?, ?, ?, ?, ?, ?)", (experiment_name, best_config, best_parameter, best_value, - recommended_config, date.today())) - conn.commit() - print("Record created successfully") - - -def get_data(): - global conn - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - - result = cursor.execute("SELECT name, search_space, objective_function from experiment") - for row in result: - print("NAME = ", row[0]) - print("SEARCH_SPACE = ", row[1]) - print("OBJECTIVE_FUNCTION = ", row[2], "\n") - - print("Operation done successfully") diff --git a/src/db_files/__init__.py b/src/db_files/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/db_files/db_connection.py b/src/db_files/db_connection.py new file mode 100644 index 0000000..507ef5a --- /dev/null +++ b/src/db_files/db_connection.py @@ -0,0 +1,143 @@ +""" +Copyright (c) 2020, 2022 Red Hat, IBM Corporation and others. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import json +import os +import sqlite3 + +from datetime import datetime + +db_path = os.path.abspath("hpo.db") + + +def conn_create(): + conn = sqlite3.connect(db_path) + + cursor = conn.cursor() + + try: + cursor.execute('''CREATE TABLE experiment + (name VARCHAR(512) PRIMARY KEY NOT NULL, + search_space TEXT NOT NULL, + objective_function VARCHAR(512) NOT NULL, + created_at TIMESTAMP);''') + print("Experiment Table created successfully") + except: + sqlite3.OperationalError + + try: + cursor.execute('''CREATE TABLE experiment_details + (id INTEGER PRIMARY KEY, + trial_number INTEGER, + experiment_name VARCHAR NOT NULL, + trial_config VARCHAR, + results_value FLOAT, + trial_result_status VARCHAR, + created_at TIMESTAMP, + FOREIGN KEY (experiment_name) + REFERENCES experiment (name));''') + print("Experiment_Details Table created successfully") + except: + sqlite3.OperationalError + + try: + cursor.execute('''CREATE TABLE configs + (id INTEGER PRIMARY KEY, + experiment_name VARCHAR, + best_parameter VARCHAR, + best_value FLOAT, + best_trial VARCHAR, + recommended_config VARCHAR, + created_at TIMESTAMP, + FOREIGN KEY (experiment_name) + REFERENCES experiment (name));''') + print("Tunables Table created successfully") + except: + sqlite3.OperationalError + + conn.close() + + +def insert_experiment_data(experiment_name, search_space_json, obj_function): + conn = sqlite3.connect(db_path) + try: + conn.execute("INSERT INTO experiment (name,search_space,objective_function,created_at) " + "VALUES (?, ?, ?, ?)", (experiment_name, str(search_space_json), obj_function, datetime.now())) + except: + sqlite3.IntegrityError + return "Experiment already exists!" + + conn.commit() + print("Record created successfully") + conn.close() + + +def insert_experiment_details(json_object, trial_json): + conn = sqlite3.connect(db_path) + conn.execute("INSERT INTO experiment_details (trial_number,experiment_name,trial_config,results_value," + "trial_result_status, created_at) " + "VALUES (?, ?, ?, ?, ?, ?)", (json_object["trial_number"], json_object["experiment_name"], + trial_json, json_object["result_value"], + json_object["trial_result"], datetime.now())) + conn.commit() + print("Record created successfully") + conn.close() + + +def insert_config_details(experiment_name, best_parameter, best_value, best_trial, recommended_config): + conn = sqlite3.connect(db_path) + conn.execute("INSERT INTO configs (experiment_name, best_parameter, best_value, best_trial, recommended_config, " + "created_at) VALUES (?, ?, ?, ?, ?, ?)", (experiment_name, best_parameter, best_value, best_trial, + recommended_config, datetime.now())) + conn.commit() + print("Final config records are inserted successfully for experiment: ", experiment_name) + conn.close() + + +def get_recommended_configs(trial_number, experiment_name): + json_list = [] + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + # check if the requested experiment is present in DB + cursor.execute("SELECT EXISTS(SELECT 1 from experiment_details where experiment_name=:experiment_name) ", + {"experiment_name": experiment_name}) + query_result = cursor.fetchall()[0][0] + if query_result == 0: + return "Experiment not found" + + # check if the requested trials has been completed or not + cursor.execute("SELECT count(trial_number) from experiment_details where experiment_name=:experiment_name ", + {"experiment_name": experiment_name}) + query_result = cursor.fetchall()[0][0] + if query_result < trial_number: + return "Trials not completed yet or exceeds the provided trial limit" + + print("Fetching best configs from top {} trials...\n".format(trial_number)) + + result = cursor.execute("SELECT experiment_name, trial_number, trial_config, results_value,trial_result_status " + "from experiment_details where experiment_name=:experiment_name and trial_number " + "between 0 and :trial_number order by results_value", + {"experiment_name": experiment_name, "trial_number": trial_number - 1}) + + rank = 1 + for row in result.fetchall(): + json_dict = {'Rank': rank, 'Experiment_Name': row[0], 'Trial_Number': row[1], 'Trial_Config': row[2], + 'Results_Value': row[3], 'Trial_Result_Status': row[4]} + json_list.append(json_dict) + rank += 1 + + conn.close() + return json.dumps(json_list) diff --git a/src/rest_service.py b/src/rest_service.py index 470f57e..9318ae0 100644 --- a/src/rest_service.py +++ b/src/rest_service.py @@ -27,7 +27,7 @@ from logger import get_logger import hpo_service -import db_connection +from db_files import db_connection logger = get_logger(__name__) @@ -37,6 +37,7 @@ search_space_json = [] api_endpoint = "/experiment_trials" +api_endpoint_recommendation = "/recommendations" host_name = "0.0.0.0" server_port = 8085 @@ -98,12 +99,39 @@ def do_GET(self): data = hpo_service.instance.get_trial_json_object(query["experiment_name"][0]) self._set_response(200, data) else: - self._set_response(404, "-1") + self._set_response(404, "Invalid URL or missing required parameters!") + elif re.search(api_endpoint_recommendation, self.path): + query = parse_qs(urlparse(self.path).query) + # check if the request contains 'experiment_name' and 'trials' + if "experiment_name" in query and "trials" in query: + self.getRecommendations(query) + else: + self._set_response(404, "Invalid URL or missing required parameters!") elif self.path == "/": data = self.getHomeScreen() self._set_response(200, data) else: - self._set_response(403, "-1") + self._set_response(404, "Error! The requested resource could not be found.") + + def getRecommendations(self, query): + experiment = query["experiment_name"][0] + trial_result_needed = int(query["trials"][0]) + if trial_result_needed <= 0: + data = "Invalid Trials value. Should be greater than 0" + logger.error(data) + self._set_response(403, data) + return + + # call database to fetch the configs + db_response = db_connection.get_recommended_configs(trial_result_needed, experiment) + + # check if the response is valid JSON else return the corresponding error response + try: + json.loads(db_response) + self._set_response(200, db_response) + except ValueError: + logger.error(db_response) + self._set_response(403, db_response) def getHomeScreen(self): fin = open(welcome_page) @@ -121,8 +149,15 @@ def handle_generate_new_operation(self, json_object): self._set_response(400, "-1") return obj_function = search_space_json["objective_function"] + + # call db function to open a connection and insert data in experiments table db_connection.conn_create() - db_connection.insert_experiment_data(experiment_name, search_space_json, obj_function) + response = db_connection.insert_experiment_data(experiment_name, search_space_json, obj_function) + if response: + logger.error(response) + self._set_response(403, response) + return + get_search_create_study(search_space_json, json_object["operation"]) trial_number = hpo_service.instance.get_trial_number(json_object["search_space"]["experiment_name"]) self._set_response(200, str(trial_number)) @@ -149,9 +184,11 @@ def handle_result_operation(self, json_object): hpo_service.instance.set_result(json_object["experiment_name"], json_object["trial_result"], json_object["result_value_type"], json_object["result_value"]) - # call db function to store experiment details after each trial trial_json = hpo_service.instance.get_trial_json_object(json_object["experiment_name"]) + + # call db_files function to store experiment details after each trial db_connection.insert_experiment_details(json_object, trial_json) + self._set_response(200, "0") else: self._set_response(400, "-1") From e86f0c6a1b9a6db7f15efc5238805e1502c5eec7 Mon Sep 17 00:00:00 2001 From: Saad Khan Date: Wed, 1 Jun 2022 13:58:25 +0530 Subject: [PATCH 5/6] Update code to support PostGreSQL Signed-off-by: Saad Khan --- {src/db_files => db}/__init__.py | 0 db/config.py | 37 ++++++++ db/operations.py | 146 +++++++++++++++++++++++++++++++ db/pg_connection.py | 40 +++++++++ db/tables.py | 70 +++++++++++++++ hpo.db | Bin 20480 -> 0 bytes src/bayes_optuna/optuna_hpo.py | 5 -- src/db_files/db_connection.py | 143 ------------------------------ src/rest_service.py | 23 +++-- 9 files changed, 308 insertions(+), 156 deletions(-) rename {src/db_files => db}/__init__.py (100%) create mode 100644 db/config.py create mode 100644 db/operations.py create mode 100644 db/pg_connection.py create mode 100644 db/tables.py delete mode 100644 hpo.db delete mode 100644 src/db_files/db_connection.py diff --git a/src/db_files/__init__.py b/db/__init__.py similarity index 100% rename from src/db_files/__init__.py rename to db/__init__.py diff --git a/db/config.py b/db/config.py new file mode 100644 index 0000000..ed2d1ac --- /dev/null +++ b/db/config.py @@ -0,0 +1,37 @@ +""" +Copyright (c) 2020, 2022 Red Hat, IBM Corporation and others. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import os +from configparser import ConfigParser + +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +def config(filename=BASE_DIR + '/db/database.ini', section='postgresql'): + # create a parser + parser = ConfigParser() + # read config file + parser.read(filename) + + # get section, default to postgresql + db = {} + if parser.has_section(section): + params = parser.items(section) + for param in params: + db[param[0]] = param[1] + else: + raise Exception('Section {0} not found in the {1} file'.format(section, filename)) + + return db diff --git a/db/operations.py b/db/operations.py new file mode 100644 index 0000000..bfee971 --- /dev/null +++ b/db/operations.py @@ -0,0 +1,146 @@ +""" +Copyright (c) 2020, 2022 Red Hat, IBM Corporation and others. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import json +import psycopg2 + +from datetime import datetime +from logger import get_logger + +import pg_connection + +rank = 1 + +logger = get_logger(__name__) + + +def insert_experiment_data(experiment_name, search_space_json, obj_function): + """ insert a new experiment, search_space_json and objective_function into the experiment table """ + sql = """INSERT INTO experiment(experiment_name,search_space,objective_function,created_at) + VALUES(%s, %s, %s, %s);""" + conn = None + try: + experiment_name = experiment_name.replace("-", "_") + conn = pg_connection.connect_to_pg() + # create a new cursor + cur = conn.cursor() + # execute the INSERT statement + cur.execute(sql, (experiment_name, str(search_space_json), obj_function, datetime.now())) + + # commit the changes to the database + conn.commit() + # close communication with the database + cur.close() + except (Exception, psycopg2.DatabaseError) as error: + logger.error(error) + return error + finally: + if conn is not None: + conn.close() + + +def insert_trial_details(json_object, trial_json): + """ insert experiment's trial details """ + sql = """INSERT INTO experiment_trial_details (trial_number, rank, experiment_name, trial_config, results_value, + trial_result_status, created_at) VALUES(%s, %s, %s, %s, %s, %s, %s);""" + conn = None + global rank + try: + experiment_name = str(json_object["experiment_name"]).replace("-", "_") + conn = pg_connection.connect_to_pg() + # create a new cursor + cur = conn.cursor() + # execute the INSERT statement + cur.execute(sql, (json_object["trial_number"], rank, experiment_name, trial_json, + json_object["result_value"], json_object["trial_result"], datetime.now())) + + # commit the changes to the database + conn.commit() + rank += 1 + # close communication with the database + cur.close() + # call the function to sort the result_value and update the rank column + response = update_rank(conn, experiment_name) + if response: + return response + except (Exception, psycopg2.DatabaseError) as error: + logger.error(error) + return error + finally: + if conn is not None: + conn.close() + + +def update_rank(conn, experiment_name): + cur = conn.cursor() + sql = "select results_value from experiment_trial_details where experiment_name = '{}' order by results_value"\ + .format(experiment_name) + cur.execute(sql) + new_rank = 1 + try: + for row in cur.fetchall(): + sql = "UPDATE experiment_trial_details SET rank = {} where results_value = {} and experiment_name = '{}'"\ + .format(new_rank, row[0], experiment_name) + cur.execute(sql) + new_rank += 1 + conn.commit() + except (Exception, psycopg2.DatabaseError) as error: + logger.error(error) + return error + + +def get_recommended_configs(trial_number, experiment_name): + conn = None + json_list = [] + try: + conn = pg_connection.connect_to_pg() + cur = conn.cursor() + + # check if the requested experiment is present in DB + sql = "SELECT EXISTS(SELECT 1 from experiment_trial_details where experiment_name = '{}')"\ + .format(experiment_name) + cur.execute(sql) + query_result = cur.fetchall()[0][0] + if query_result == 0: + return "Experiment not found" + + # check if the requested trials has been completed or not + sql = "SELECT count(trial_number) from experiment_trial_details where experiment_name = '{}'"\ + .format(experiment_name) + cur.execute(sql) + query_result = cur.fetchall()[0][0] + if query_result < trial_number: + return "Trials not completed yet or exceeds the provided trial limit" + + print("Fetching best configs from top {} trials...\n".format(trial_number)) + sql = "SELECT trial_number,rank,experiment_name,trial_config, results_value,trial_result_status from " \ + "experiment_trial_details where experiment_name = '{}' and trial_number between 0 and {} order by rank"\ + .format(experiment_name, trial_number - 1) + + cur.execute(sql) + for row in cur.fetchall(): + json_dict = {'Trial_Number': row[0], 'Rank': row[1], 'Experiment_Name': row[2], 'Trial_Config': row[3], + 'Results_Value': row[4], 'Trial_Result_Status': row[5]} + json_list.append(json_dict) + + cur.close() + + except (Exception, psycopg2.DatabaseError) as error: + logger.error(error) + return error + finally: + if conn is not None: + conn.close() + return json.dumps(json_list) diff --git a/db/pg_connection.py b/db/pg_connection.py new file mode 100644 index 0000000..933cf81 --- /dev/null +++ b/db/pg_connection.py @@ -0,0 +1,40 @@ +""" +Copyright (c) 2020, 2022 Red Hat, IBM Corporation and others. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import psycopg2 + +from config import config +from logger import get_logger +logger = get_logger(__name__) + + +def connect_to_pg(): + """ Connect to the PostgreSQL database server """ + conn = None + try: + # read connection parameters + params = config() + + # connect to the PostgreSQL server + logger.info('Connecting to the PostgreSQL database...') + conn = psycopg2.connect(**params) + logger.info('Successfully Connected!') + + return conn + except (Exception, psycopg2.DatabaseError) as error: + logger.error(error) + if conn is not None: + conn.close() + logger.info('Database connection closed.') diff --git a/db/tables.py b/db/tables.py new file mode 100644 index 0000000..28911fd --- /dev/null +++ b/db/tables.py @@ -0,0 +1,70 @@ +""" +Copyright (c) 2020, 2022 Red Hat, IBM Corporation and others. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import os +import sys +import psycopg2 +from logger import get_logger + +file_dir = os.path.dirname(__file__) +sys.path.append(file_dir) +import pg_connection + +logger = get_logger(__name__) + + +def create_tables(): + """ create tables in the PostgreSQL database""" + commands = ( + """ + CREATE TABLE experiment ( + experiment_name VARCHAR(512) PRIMARY KEY NOT NULL, + search_space TEXT NOT NULL, + objective_function VARCHAR(512) NOT NULL, + created_at TIMESTAMP + ) + """, + """ + CREATE TABLE experiment_trial_details ( + id SERIAL PRIMARY KEY, + trial_number INTEGER, + rank INTEGER, + experiment_name VARCHAR NOT NULL, + trial_config VARCHAR, + results_value FLOAT, + trial_result_status VARCHAR, + created_at TIMESTAMP, + FOREIGN KEY (experiment_name) + REFERENCES experiment (experiment_name) + ON UPDATE CASCADE ON DELETE CASCADE + ) + """) + conn = None + try: + conn = pg_connection.connect_to_pg() + cur = conn.cursor() + # create table one by one + for command in commands: + cur.execute(command) + logger.info("Tables created") + # close communication with the PostgreSQL database server + cur.close() + # commit the changes + conn.commit() + except (Exception, psycopg2.DatabaseError) as error: + logger.error(error) + finally: + if conn is not None: + conn.close() diff --git a/hpo.db b/hpo.db deleted file mode 100644 index f92adfeeab03ee9c190e4456e51b590e4b16cabd..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 20480 zcmeHPOK;mo5T<0uZ%4iPkPG9Ykp;*EmlP>VsL&{`{6PGwq%;l!gP^6g&6Xl%k;=n{ z4{Fl*+@H`(PC*9afadU{xBQ*n;D&6XRiMo z`eo)XDA0!@KoOt_Py{Ff6ak6=MSvpk5g_m^JUTvi<3`}cXO6aK;DTlD>j!q@LqTQ7ENf}r45;h&x)JmBB zI-OlwO=o*h?qSlpM$Ko)5py0Pt(U^Z^qt`Hu_KmC;}7#iU396<3wQe#m6ZZm5L%j5%@3= zc#eW&xwH5e=`wZ-hHmPGMO!PC4ZJ8U#wA6G6(lK^STxIz4vHt?_D^>=+7Bn$@Brh& zQPw+BFOqQ@5=7GBSdM0>Zh&V)y9iBWlfW>rrxih~8y>38(=sN*i&hmjHwc5nvZZRq zfu-uOY|>?wovNv+4iG|jcdSEboQh`Jnj8A45LT8g)5fX;`xx$$U^p-*%XxU>JFj6q z#uatXs+vR|B=ImA35f~&5-wSlr&)Ydg-s8_vRa01k`}SZ6Jp!JWis+K$__k>XO2R- z>OVo7J!%==oLv zY{u?kGvbgo(>;!W=CQVx&BR;bn+B#B7! zvcDvVLy5Q;kJVO3`j}UDKsS`r^Rc0nf;gFoa^IE2$LXb{AaOB;=i^dsCFpNnmn2Cd zF_a_%*v&;g50W5GfjkTW93|jyUY8>Y8rtQ>3(ZSPBB^oy@8@YKn@P|GW0g3=c zfFeKLE)rym9R=H+jfdr{hR*JGnZoEpA6(Jis z#tR8qh2R+m;8rKQkymM7A5|xpA#W>|EGDnKs($>cI zdM2M~Xfw5GnK*J1Zsj7o`Rkq}?f^s zbKe|pGEI{UvBYNed>TDK7H(|PG_u~CvUK`GbG*kfT=qTqwO8kH9q)MXp=Pyf`0nYc zch^QqBEFq0e6}5^JM}>?r|yn>de}ONG+}bj${-J)gptiCvWHWHd$ycf9MY*lVhNNW zC!`n>hxKVNRWTlA<4RoN<9%J)SIqx~OM}AR+NIs=1%tSBx0$^Umj*lZ^&#kZTbFLw P%djp@Jji!BeO>w=_qc1! diff --git a/src/bayes_optuna/optuna_hpo.py b/src/bayes_optuna/optuna_hpo.py index 8284e4a..8903f22 100644 --- a/src/bayes_optuna/optuna_hpo.py +++ b/src/bayes_optuna/optuna_hpo.py @@ -183,11 +183,6 @@ def recommend(self): logger.info("Recommended config: " + str(recommended_config)) - # call db function to store the configs - - db_connection.insert_config_details(self.experiment_name, str(study.best_params), study.best_value, - str(study.best_trial), str(recommended_config)) - class Objective(TrialDetails): """ diff --git a/src/db_files/db_connection.py b/src/db_files/db_connection.py deleted file mode 100644 index 507ef5a..0000000 --- a/src/db_files/db_connection.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Copyright (c) 2020, 2022 Red Hat, IBM Corporation and others. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" -import json -import os -import sqlite3 - -from datetime import datetime - -db_path = os.path.abspath("hpo.db") - - -def conn_create(): - conn = sqlite3.connect(db_path) - - cursor = conn.cursor() - - try: - cursor.execute('''CREATE TABLE experiment - (name VARCHAR(512) PRIMARY KEY NOT NULL, - search_space TEXT NOT NULL, - objective_function VARCHAR(512) NOT NULL, - created_at TIMESTAMP);''') - print("Experiment Table created successfully") - except: - sqlite3.OperationalError - - try: - cursor.execute('''CREATE TABLE experiment_details - (id INTEGER PRIMARY KEY, - trial_number INTEGER, - experiment_name VARCHAR NOT NULL, - trial_config VARCHAR, - results_value FLOAT, - trial_result_status VARCHAR, - created_at TIMESTAMP, - FOREIGN KEY (experiment_name) - REFERENCES experiment (name));''') - print("Experiment_Details Table created successfully") - except: - sqlite3.OperationalError - - try: - cursor.execute('''CREATE TABLE configs - (id INTEGER PRIMARY KEY, - experiment_name VARCHAR, - best_parameter VARCHAR, - best_value FLOAT, - best_trial VARCHAR, - recommended_config VARCHAR, - created_at TIMESTAMP, - FOREIGN KEY (experiment_name) - REFERENCES experiment (name));''') - print("Tunables Table created successfully") - except: - sqlite3.OperationalError - - conn.close() - - -def insert_experiment_data(experiment_name, search_space_json, obj_function): - conn = sqlite3.connect(db_path) - try: - conn.execute("INSERT INTO experiment (name,search_space,objective_function,created_at) " - "VALUES (?, ?, ?, ?)", (experiment_name, str(search_space_json), obj_function, datetime.now())) - except: - sqlite3.IntegrityError - return "Experiment already exists!" - - conn.commit() - print("Record created successfully") - conn.close() - - -def insert_experiment_details(json_object, trial_json): - conn = sqlite3.connect(db_path) - conn.execute("INSERT INTO experiment_details (trial_number,experiment_name,trial_config,results_value," - "trial_result_status, created_at) " - "VALUES (?, ?, ?, ?, ?, ?)", (json_object["trial_number"], json_object["experiment_name"], - trial_json, json_object["result_value"], - json_object["trial_result"], datetime.now())) - conn.commit() - print("Record created successfully") - conn.close() - - -def insert_config_details(experiment_name, best_parameter, best_value, best_trial, recommended_config): - conn = sqlite3.connect(db_path) - conn.execute("INSERT INTO configs (experiment_name, best_parameter, best_value, best_trial, recommended_config, " - "created_at) VALUES (?, ?, ?, ?, ?, ?)", (experiment_name, best_parameter, best_value, best_trial, - recommended_config, datetime.now())) - conn.commit() - print("Final config records are inserted successfully for experiment: ", experiment_name) - conn.close() - - -def get_recommended_configs(trial_number, experiment_name): - json_list = [] - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - - # check if the requested experiment is present in DB - cursor.execute("SELECT EXISTS(SELECT 1 from experiment_details where experiment_name=:experiment_name) ", - {"experiment_name": experiment_name}) - query_result = cursor.fetchall()[0][0] - if query_result == 0: - return "Experiment not found" - - # check if the requested trials has been completed or not - cursor.execute("SELECT count(trial_number) from experiment_details where experiment_name=:experiment_name ", - {"experiment_name": experiment_name}) - query_result = cursor.fetchall()[0][0] - if query_result < trial_number: - return "Trials not completed yet or exceeds the provided trial limit" - - print("Fetching best configs from top {} trials...\n".format(trial_number)) - - result = cursor.execute("SELECT experiment_name, trial_number, trial_config, results_value,trial_result_status " - "from experiment_details where experiment_name=:experiment_name and trial_number " - "between 0 and :trial_number order by results_value", - {"experiment_name": experiment_name, "trial_number": trial_number - 1}) - - rank = 1 - for row in result.fetchall(): - json_dict = {'Rank': rank, 'Experiment_Name': row[0], 'Trial_Number': row[1], 'Trial_Config': row[2], - 'Results_Value': row[3], 'Trial_Result_Status': row[4]} - json_list.append(json_dict) - rank += 1 - - conn.close() - return json.dumps(json_list) diff --git a/src/rest_service.py b/src/rest_service.py index 9318ae0..acc626d 100644 --- a/src/rest_service.py +++ b/src/rest_service.py @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ - +import sys from http.server import BaseHTTPRequestHandler, HTTPServer import re import cgi @@ -27,7 +27,10 @@ from logger import get_logger import hpo_service -from db_files import db_connection + +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, BASE_DIR) +from db import tables, operations logger = get_logger(__name__) @@ -123,7 +126,7 @@ def getRecommendations(self, query): return # call database to fetch the configs - db_response = db_connection.get_recommended_configs(trial_result_needed, experiment) + db_response = operations.get_recommended_configs(trial_result_needed, experiment) # check if the response is valid JSON else return the corresponding error response try: @@ -151,11 +154,11 @@ def handle_generate_new_operation(self, json_object): obj_function = search_space_json["objective_function"] # call db function to open a connection and insert data in experiments table - db_connection.conn_create() - response = db_connection.insert_experiment_data(experiment_name, search_space_json, obj_function) + tables.create_tables() + response = operations.insert_experiment_data(experiment_name, search_space_json, obj_function) + if response: - logger.error(response) - self._set_response(403, response) + self._set_response(403, "-1") return get_search_create_study(search_space_json, json_object["operation"]) @@ -187,7 +190,11 @@ def handle_result_operation(self, json_object): trial_json = hpo_service.instance.get_trial_json_object(json_object["experiment_name"]) # call db_files function to store experiment details after each trial - db_connection.insert_experiment_details(json_object, trial_json) + response = operations.insert_trial_details(json_object, trial_json) + + if response: + self._set_response(403, "-1") + return self._set_response(200, "0") else: From 0101f2c9adc3d0958fe62c4d29a8b3ac4d670af0 Mon Sep 17 00:00:00 2001 From: Saad Khan Date: Wed, 1 Jun 2022 16:13:18 +0530 Subject: [PATCH 6/6] Update query to fetch recommendations, bug fixes Signed-off-by: Saad Khan --- db/operations.py | 43 +++++++++++++++++++++------------- db/tables.py | 2 +- src/bayes_optuna/optuna_hpo.py | 1 - src/rest_service.py | 14 +++++------ 4 files changed, 35 insertions(+), 25 deletions(-) diff --git a/db/operations.py b/db/operations.py index bfee971..1e9bca6 100644 --- a/db/operations.py +++ b/db/operations.py @@ -83,6 +83,7 @@ def insert_trial_details(json_object, trial_json): conn.close() +# Update rank in the table based on the results_value def update_rank(conn, experiment_name): cur = conn.cursor() sql = "select results_value from experiment_trial_details where experiment_name = '{}' order by results_value"\ @@ -113,23 +114,32 @@ def get_recommended_configs(trial_number, experiment_name): .format(experiment_name) cur.execute(sql) query_result = cur.fetchall()[0][0] - if query_result == 0: - return "Experiment not found" - - # check if the requested trials has been completed or not - sql = "SELECT count(trial_number) from experiment_trial_details where experiment_name = '{}'"\ - .format(experiment_name) - cur.execute(sql) - query_result = cur.fetchall()[0][0] - if query_result < trial_number: - return "Trials not completed yet or exceeds the provided trial limit" + if not query_result: + return + + query = "SELECT trial_number,rank,experiment_name,trial_config, results_value,trial_result_status from " \ + "experiment_trial_details " + # if the trial value is 0, return all the records sorted by rank + if trial_number == 0: + logger.info("Fetching all the trials based on rank...") + query += "order by rank" + cur.execute(query) + # else fetch the records based on requested trial_number + else: + # check if the requested trials has been completed or not + sql = "SELECT count(trial_number) from experiment_trial_details where experiment_name = '{}'" \ + .format(experiment_name) + cur.execute(sql) + query_result = cur.fetchall()[0][0] + if query_result < trial_number: + logger.info("\nRequested trial exceeds the completed trial limit!") + trial_number = query_result - print("Fetching best configs from top {} trials...\n".format(trial_number)) - sql = "SELECT trial_number,rank,experiment_name,trial_config, results_value,trial_result_status from " \ - "experiment_trial_details where experiment_name = '{}' and trial_number between 0 and {} order by rank"\ - .format(experiment_name, trial_number - 1) + logger.info("Fetching best configs from top {} trials...\n".format(trial_number)) + query += "where experiment_name = '{}' and rank <= {} order by rank".format(experiment_name, trial_number) + cur.execute(query) - cur.execute(sql) + # Store the result fetched from DB in a dictionary and create JSON from it for row in cur.fetchall(): json_dict = {'Trial_Number': row[0], 'Rank': row[1], 'Experiment_Name': row[2], 'Trial_Config': row[3], 'Results_Value': row[4], 'Trial_Result_Status': row[5]} @@ -143,4 +153,5 @@ def get_recommended_configs(trial_number, experiment_name): finally: if conn is not None: conn.close() - return json.dumps(json_list) + if json_list: + return json.dumps(json_list) diff --git a/db/tables.py b/db/tables.py index 28911fd..5b6e84a 100644 --- a/db/tables.py +++ b/db/tables.py @@ -39,7 +39,7 @@ def create_tables(): """ CREATE TABLE experiment_trial_details ( id SERIAL PRIMARY KEY, - trial_number INTEGER, + trial_number VARCHAR, rank INTEGER, experiment_name VARCHAR NOT NULL, trial_config VARCHAR, diff --git a/src/bayes_optuna/optuna_hpo.py b/src/bayes_optuna/optuna_hpo.py index 8903f22..1e46621 100644 --- a/src/bayes_optuna/optuna_hpo.py +++ b/src/bayes_optuna/optuna_hpo.py @@ -19,7 +19,6 @@ from logger import get_logger -from db_files import db_connection logger = get_logger(__name__) diff --git a/src/rest_service.py b/src/rest_service.py index acc626d..3be7d3a 100644 --- a/src/rest_service.py +++ b/src/rest_service.py @@ -117,16 +117,16 @@ def do_GET(self): self._set_response(404, "Error! The requested resource could not be found.") def getRecommendations(self, query): - experiment = query["experiment_name"][0] + experiment_name = str(query["experiment_name"][0]).replace("-", "_") trial_result_needed = int(query["trials"][0]) - if trial_result_needed <= 0: - data = "Invalid Trials value. Should be greater than 0" + if trial_result_needed < 0: + data = "Invalid Trials value" logger.error(data) self._set_response(403, data) return - # call database to fetch the configs - db_response = operations.get_recommended_configs(trial_result_needed, experiment) + # call database operations function to fetch the configs + db_response = operations.get_recommended_configs(trial_result_needed, experiment_name) # check if the response is valid JSON else return the corresponding error response try: @@ -134,7 +134,7 @@ def getRecommendations(self, query): self._set_response(200, db_response) except ValueError: logger.error(db_response) - self._set_response(403, db_response) + self._set_response(403, "-1") def getHomeScreen(self): fin = open(welcome_page) @@ -189,7 +189,7 @@ def handle_result_operation(self, json_object): json_object["result_value"]) trial_json = hpo_service.instance.get_trial_json_object(json_object["experiment_name"]) - # call db_files function to store experiment details after each trial + # call db operations function to store experiment details after each trial response = operations.insert_trial_details(json_object, trial_json) if response: