diff --git a/src/services/slack_service.py b/src/services/slack_service.py index cc3f2ba..7369154 100644 --- a/src/services/slack_service.py +++ b/src/services/slack_service.py @@ -6,28 +6,47 @@ # grabs the credentials from .env directly slack_app = App() + + users_map = {} + def send_message(channel: str, thread_ts: str, text: str): - try: - retry(lambda: slack_app.client.chat_postMessage( + response = retry( + lambda: slack_app.client.chat_postMessage( token=os.environ["SLACK_BOT_TOKEN"], channel=channel, text=text, thread_ts=thread_ts, - )) - except Exception as e: - print(e) + ) + ) + return response["ts"] + + +def update_message(channel: str, thread_ts: str, ts: str, text: str): + response = retry( + lambda: slack_app.client.chat_update( + channel=channel, ts=ts, thread_ts=thread_ts, text=text + ) + ) + return response["ts"] + + +def delete_message(channel: str, ts: str): + response = retry(lambda: slack_app.client.chat_delete(channel=channel, ts=ts)) + return response["ts"] def get_thread_messages(channel: str, thread_ts: str): try: - return retry(lambda: slack_app.client.conversations_replies( - token=os.environ["SLACK_BOT_TOKEN"], - channel=channel, - ts=thread_ts, - include_all_metadata=True, - )["messages"]) + return retry( + lambda: slack_app.client.conversations_replies( + token=os.environ["SLACK_BOT_TOKEN"], + channel=channel, + ts=thread_ts, + include_all_metadata=True, + )["messages"] + ) except Exception as e: print(e) @@ -56,6 +75,7 @@ def get_username(user_id: str): users_map[user_id] = user["user"]["name"] return users_map[user_id] + def find_bot_id(payload): for auth in payload["authorizations"]: # Check if the current authorization is a bot @@ -64,6 +84,7 @@ def find_bot_id(payload): return None + def get_sender(payload): return payload.get("event").get("user") @@ -92,6 +113,9 @@ def process_event_payload(payload): thread_to_reply = thread_ts if thread_ts != ts: thread_to_reply = ts + + msg_ts = send_message(channel, thread_to_reply, "Thinking...") + messages = f"USER: {text.replace(f'<@{bot_id}>','').strip()}" if thread_ts: messages = ( @@ -104,8 +128,9 @@ def process_event_payload(payload): end_time = time.perf_counter() print(f"response generated in {round(end_time - start_time, 2)}s") - return send_message(channel, thread_to_reply, response) + return update_message(channel, thread_to_reply, msg_ts, response) except Exception as error: # Improve error handling print(error) + delete_message(channel, msg_ts) return