diff --git a/.gitignore b/.gitignore index 3015c69..5b8b06c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ .env venv/ +.venv/ __pycache__/ *.swp *.log diff --git a/README.md b/README.md index 7fbeb21..c99378c 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ If you want to use hackingBuddyGPT and need help selecting the best LLM for your ## hackingBuddyGPT in the News -- **upcoming** 2024-11-20: [Manuel Reinsperger](https://www.github.com/neverbolt) will present hackingBuddyGPT at the [European Symposium on Security and Artificial Intelligence (ESSAI)](https://essai-conference.eu/) +- 2024-11-20: [Manuel Reinsperger](https://www.github.com/neverbolt) presented hackingBuddyGPT at the [European Symposium on Security and Artificial Intelligence (ESSAI)](https://essai-conference.eu/) - 2024-07-26: The [GitHub Accelerator Showcase](https://github.blog/open-source/maintainers/github-accelerator-showcase-celebrating-our-second-cohort-and-whats-next/) features hackingBuddyGPT - 2024-07-24: [Juergen](https://github.com/citostyle) speaks at [Open Source + mezcal night @ GitHub HQ](https://lu.ma/bx120myg) - 2024-05-23: hackingBuddyGPT is part of [GitHub Accelerator 2024](https://github.blog/news-insights/company-news/2024-github-accelerator-meet-the-11-projects-shaping-open-source-ai/) @@ -82,38 +82,38 @@ template_next_cmd = Template(filename=str(template_dir / "next_cmd.txt")) class MinimalLinuxPrivesc(Agent): - conn: SSHConnection = None + _sliding_history: SlidingCliHistory = None + _max_history_size: int = 0 def init(self): super().init() + self._sliding_history = SlidingCliHistory(self.llm) + self._max_history_size = self.llm.context_size - llm_util.SAFETY_MARGIN - self.llm.count_tokens(template_next_cmd.source) + self.add_capability(SSHRunCommand(conn=self.conn), default=True) self.add_capability(SSHTestCredential(conn=self.conn)) - self._template_size = self.llm.count_tokens(template_next_cmd.source) - def perform_round(self, turn: int) -> bool: - 
got_root: bool = False + @log_conversation("Asking LLM for a new command...") + def perform_round(self, turn: int, log: Logger) -> bool: + # get as much history as fits into the target context size + history = self._sliding_history.get_history(self._max_history_size) - with self._log.console.status("[bold green]Asking LLM for a new command..."): - # get as much history as fits into the target context size - history = self._sliding_history.get_history(self.llm.context_size - llm_util.SAFETY_MARGIN - self._template_size) + # get the next command from the LLM + answer = self.llm.get_response(template_next_cmd, capabilities=self.get_capability_block(), history=history, conn=self.conn) + message_id = log.call_response(answer) - # get the next command from the LLM - answer = self.llm.get_response(template_next_cmd, capabilities=self.get_capability_block(), history=history, conn=self.conn) - cmd = llm_util.cmd_output_fixer(answer.result) + # clean the command, load and execute it + cmd = llm_util.cmd_output_fixer(answer.result) + capability, arguments = cmd.split(" ", 1) + result, got_root = self.run_capability(message_id, "0", capability, arguments, calling_mode=CapabilityCallingMode.Direct, log=log) - with self._log.console.status("[bold green]Executing that command..."): - self._log.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:")) - result, got_root = self.get_capability(cmd.split(" ", 1)[0])(cmd) - - # log and output the command and its result - self._log.log_db.add_log_query(self._log.run_id, turn, cmd, result, answer) + # store the results in our local history self._sliding_history.add_command(cmd, result) - self._log.console.print(Panel(result, title=f"[bold cyan]{cmd}")) - # if we got root, we can stop the loop + # signal if we were successful in our task return got_root @@ -306,6 +306,22 @@ Mac, Docker Desktop and Gemini-OpenAI-Proxy: * See https://github.com/ipa-lab/hackingBuddyGPT/blob/main/MAC.md +## Beta Features + +### Viewer + 
+The viewer is a simple web-based tool to view the results of hackingBuddyGPT runs. It is currently in beta and can be started with: + +```bash +$ hackingBuddyGPT Viewer +``` + +This will start a webserver on `http://localhost:4444` that can be accessed with a web browser. + +To log to this central viewer, you currently need to change the `GlobalLogger` definition in [./src/hackingBuddyGPT/utils/logging.py](src/hackingBuddyGPT/utils/logging.py) to `GlobalRemoteLogger`. + +This feature is not fully tested yet and therefore is not recommended to be exposed to the internet! + ## Publications about hackingBuddyGPT Given our background in academia, we have authored papers that lay the groundwork and report on our efforts: diff --git a/pyproject.toml b/pyproject.toml index 8873ca6..ec439db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,19 +26,25 @@ classifiers = [ "Development Status :: 4 - Beta", ] dependencies = [ - 'fabric == 3.2.2', - 'Mako == 1.3.2', - 'requests == 2.32.0', - 'rich == 13.7.1', - 'tiktoken == 0.8.0', - 'instructor == 1.3.5', - 'PyYAML == 6.0.1', - 'python-dotenv == 1.0.1', - 'pypsexec == 0.3.0', - 'pydantic == 2.8.2', - 'openai == 1.28.0', - 'BeautifulSoup4', - 'nltk' + 'fabric == 3.2.2', + 'Mako == 1.3.2', + 'requests == 2.32.0', + 'rich == 13.7.1', + 'tiktoken == 0.8.0', + 'instructor == 1.3.5', + 'PyYAML == 6.0.1', + 'python-dotenv == 1.0.1', + 'pypsexec == 0.3.0', + 'pydantic == 2.8.2', + 'openai == 1.28.0', + 'BeautifulSoup4', + 'nltk', + 'fastapi == 0.114.0', + 'fastapi-utils == 0.7.0', + 'jinja2 == 3.1.4', + 'uvicorn[standard] == 0.30.6', + 'dataclasses_json == 0.6.7', + 'websockets == 13.1', ] [project.urls] @@ -56,14 +62,9 @@ where = ["src"] [tool.pytest.ini_options] pythonpath = "src" -addopts = [ - "--import-mode=importlib", -] +addopts = ["--import-mode=importlib"] [project.optional-dependencies] -testing = [ - 'pytest', - 'pytest-mock' -] +testing = ['pytest', 'pytest-mock'] dev = [ 'ruff', ] diff --git 
a/src/hackingBuddyGPT/capabilities/capability.py b/src/hackingBuddyGPT/capabilities/capability.py index 7a4adbb..0459a09 100644 --- a/src/hackingBuddyGPT/capabilities/capability.py +++ b/src/hackingBuddyGPT/capabilities/capability.py @@ -38,14 +38,14 @@ def get_name(self) -> str: def __call__(self, *args, **kwargs): """ The actual execution of a capability, please make sure, that the parameters and return type of your - implementation are well typed, as this will make it easier to support full function calling soon. + implementation are well typed, as this is used to properly support function calling. """ pass def to_model(self) -> BaseModel: """ Converts the parameters of the `__call__` function of the capability to a pydantic model, that can be used to - interface with an LLM using eg instructor or the openAI function calling API. + interface with an LLM using eg the openAI function calling API. The model will have the same name as the capability class and will have the same fields as the `__call__`, the `__call__` method can then be accessed by calling the `execute` method of the model. 
""" diff --git a/src/hackingBuddyGPT/capabilities/http_request.py b/src/hackingBuddyGPT/capabilities/http_request.py index b7505d2..d89f12b 100644 --- a/src/hackingBuddyGPT/capabilities/http_request.py +++ b/src/hackingBuddyGPT/capabilities/http_request.py @@ -70,7 +70,7 @@ def __call__( url = self.host + ("" if path.startswith("/") else "/") + path + ("?{query}" if query else "") return f"Could not request '{url}': {e}" - headers = "\r\n".join(f"{k}: {v}" for k, v in resp.headers.items()) + response_headers = "\r\n".join(f"{k}: {v}" for k, v in resp.headers.items()) # turn the response into "plain text format" for responding to the prompt - return f"HTTP/1.1 {resp.status_code} {resp.reason}\r\n{headers}\r\n\r\n{resp.text}" + return f"HTTP/1.1 {resp.status_code} {resp.reason}\r\n{response_headers}\r\n\r\n{resp.text}" diff --git a/src/hackingBuddyGPT/capabilities/submit_http_method.py b/src/hackingBuddyGPT/capabilities/submit_http_method.py index ab3040f..79ebb25 100644 --- a/src/hackingBuddyGPT/capabilities/submit_http_method.py +++ b/src/hackingBuddyGPT/capabilities/submit_http_method.py @@ -1,10 +1,8 @@ import base64 -import inspect from dataclasses import dataclass, field from typing import Callable, Dict, Literal, Optional, Set import requests -from pydantic import BaseModel, create_model from . import Capability @@ -23,25 +21,6 @@ class SubmitHTTPMethod(Capability): def describe(self) -> str: return f"Submits a HTTP methods, which is the goal of all activities. The HTTP methods format is {self.http_methods_format}. If you find a flag following the that format, that you have not yet submitted, then the most important first step is to submit it, ignoring all other possibilities of further action" - def to_model(self) -> BaseModel: - """ - Converts the parameters of the `__call__` function of the capability to a pydantic model, that can be used to - interface with an LLM using eg instructor or the openAI function calling API. 
- The model will have the same name as the capability class and will have the same fields as the `__call__`, - the `__call__` method can then be accessed by calling the `execute` method of the model. - """ - sig = inspect.signature(self.__call__) - fields = {param: (param_info.annotation, ...) for param, param_info in sig.parameters.items()} - model_type = create_model(self.__class__.__name__, __doc__=self.describe(), **fields) - - def execute(model): - m = model.dict() - return self(**m) - - model_type.execute = execute - - return model_type - def __call__( self, method: Literal["GET", "HEAD", "POST", "PUT", "DELETE", "OPTION", "PATCH"], diff --git a/src/hackingBuddyGPT/cli/stats.py b/src/hackingBuddyGPT/cli/stats.py deleted file mode 100755 index 6dabaa6..0000000 --- a/src/hackingBuddyGPT/cli/stats.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/python3 - -import argparse - -from rich.console import Console -from rich.table import Table -from utils.db_storage import DbStorage - -# setup infrastructure for outputing information -console = Console() - -parser = argparse.ArgumentParser(description="View an existing log file.") -parser.add_argument("log", type=str, help="sqlite3 db for reading log data") -args = parser.parse_args() -console.log(args) - -# setup in-memory/persistent storage for command history -db = DbStorage(args.log) -db.connect() -db.setup_db() - -# experiment names -names = { - "1": "suid-gtfo", - "2": "sudo-all", - "3": "sudo-gtfo", - "4": "docker", - "5": "cron-script", - "6": "pw-reuse", - "7": "pw-root", - "8": "vacation", - "9": "ps-bash-hist", - "10": "cron-wildcard", - "11": "ssh-key", - "12": "cron-script-vis", - "13": "cron-wildcard-vis", -} - -# prepare table -table = Table(title="Round Data", show_header=True, show_lines=True) -table.add_column("RunId", style="dim") -table.add_column("Description", style="dim") -table.add_column("Round", style="dim") -table.add_column("State") -table.add_column("Last Command") - -data = 
db.get_log_overview() -for run in data: - row = data[run] - table.add_row(str(run), names[str(run)], str(row["max_round"]), row["state"], row["last_cmd"]) - -console.print(table) diff --git a/src/hackingBuddyGPT/cli/viewer.py b/src/hackingBuddyGPT/cli/viewer.py deleted file mode 100755 index 4938cb5..0000000 --- a/src/hackingBuddyGPT/cli/viewer.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/python3 - -import argparse - -from rich.console import Console -from rich.panel import Panel -from rich.table import Table -from utils.db_storage import DbStorage - - -# helper to fill the history table with data from the db -def get_history_table(run_id: int, db: DbStorage, round: int) -> Table: - table = Table(title="Executed Command History", show_header=True, show_lines=True) - table.add_column("ThinkTime", style="dim") - table.add_column("Tokens", style="dim") - table.add_column("Cmd") - table.add_column("Resp. Size", justify="right") - # if config.enable_explanation: - # table.add_column("Explanation") - # table.add_column("ExplTime", style="dim") - # table.add_column("ExplTokens", style="dim") - # if config.enable_update_state: - # table.add_column("StateUpdTime", style="dim") - # table.add_column("StateUpdTokens", style="dim") - - for i in range(0, round + 1): - table.add_row(*db.get_round_data(run_id, i, explanation=False, status_update=False)) - # , config.enable_explanation, config.enable_update_state)) - - return table - - -# setup infrastructure for outputing information -console = Console() - -parser = argparse.ArgumentParser(description="View an existing log file.") -parser.add_argument("log", type=str, help="sqlite3 db for reading log data") -args = parser.parse_args() -console.log(args) - -# setup in-memory/persistent storage for command history -db = DbStorage(args.log) -db.connect() -db.setup_db() - -# setup round meta-data -run_id: int = 1 -round: int = 0 - -# read run data - -run = db.get_run_data(run_id) -while run is not None: - if run[4] is None: - 
console.print(Panel(f"run: {run[0]}/{run[1]}\ntest: {run[2]}\nresult: {run[3]}", title="Run Data")) - else: - console.print( - Panel( - f"run: {run[0]}/{run[1]}\ntest: {run[2]}\nresult: {run[3]} after {run[4]} rounds", - title="Run Data", - ) - ) - console.log(run[5]) - - # Output Round Data - console.print(get_history_table(run_id, db, run[4] - 1)) - - # fetch next run - run_id += 1 - run = db.get_run_data(run_id) diff --git a/src/hackingBuddyGPT/cli/wintermute.py b/src/hackingBuddyGPT/cli/wintermute.py index 91f865b..7ef4c19 100644 --- a/src/hackingBuddyGPT/cli/wintermute.py +++ b/src/hackingBuddyGPT/cli/wintermute.py @@ -11,8 +11,9 @@ def main(): use_case.build_parser(subparser.add_parser(name=name, help=use_case.description)) parsed = parser.parse_args(sys.argv[1:]) + configuration = {k: v for k, v in vars(parsed).items() if k not in ("use_case", "parser_state")} instance = parsed.use_case(parsed) - instance.init() + instance.init(configuration=configuration) instance.run() diff --git a/src/hackingBuddyGPT/resources/webui/static/client.js b/src/hackingBuddyGPT/resources/webui/static/client.js new file mode 100644 index 0000000..2f92daa --- /dev/null +++ b/src/hackingBuddyGPT/resources/webui/static/client.js @@ -0,0 +1,373 @@ +/* jshint esversion: 9, browser: true */ +/* global console */ + +(function() { + "use strict"; + + function debounce(func, wait = 100, immediate = false) { + let timeout; + return function () { + const context = this, + args = arguments; + const later = function () { + timeout = null; + if (!immediate) { + func.apply(context, args); + } + }; + const callNow = immediate && !timeout; + clearTimeout(timeout); + timeout = setTimeout(later, wait); + if (callNow) { + func.apply(context, args); + } + }; + } + + function isScrollAtBottom() { + const content = document.getElementById("main-body"); + console.log( + "scroll check", + content.scrollHeight, + content.scrollTop, + content.clientHeight, + ); + return content.scrollHeight - 
content.scrollTop <= content.clientHeight + 30; + } + + function scrollUpdate(wasAtBottom) { + const content = document.getElementById("main-body"); + if (wasAtBottom) { + console.log("scrolling to bottom"); + content.scrollTop = content.scrollHeight; + } + } + + const sidebar = document.getElementById("sidebar"); + const menuToggles = document.getElementsByClassName("menu-toggle"); + Array.from(menuToggles).forEach((menuToggle) => { + menuToggle.addEventListener("click", () => { + sidebar.classList.toggle("active"); + }); + }); + + let ws = null; + let currentRun = null; + + const followNewRunsCheckbox = document.getElementById("follow_new_runs"); + let followNewRuns = + !window.location.hash && localStorage.getItem("followNewRuns") === "true"; + followNewRunsCheckbox.checked = followNewRuns; + + followNewRunsCheckbox.addEventListener("change", () => { + followNewRuns = followNewRunsCheckbox.checked; + localStorage.setItem("followNewRuns", followNewRuns); + }); + + let send = function (type, data) { + const message = {type: type, data: data}; + console.log("> sending ", message); + ws.send(JSON.stringify(message)); + }; + + function initWebsocket() { + console.log("initializing websocket"); + ws = new WebSocket( + `ws${location.protocol === "https:" ? 
"s" : ""}://${location.host}/client`, + ); + + let runs = {}; + + ws.addEventListener("open", () => { + ws.addEventListener("message", (event) => { + const message = JSON.parse(event.data); + console.log("< receiving", message); + const {type, data} = message; + + const wasAtBottom = isScrollAtBottom(); + switch (type) { + case "Run": + handleRunMessage(data); + break; + case "Section": + handleSectionMessage(data); + break; + case "Message": + handleMessage(data); + break; + case "MessageStreamPart": + handleMessageStreamPart(data); + break; + case "ToolCall": + handleToolCall(data); + break; + case "ToolCallStreamPart": + handleToolCallStreamPart(data); + break; + default: + console.warn("Unknown message type:", type); + } + scrollUpdate(wasAtBottom); + }); + + function createRunListEntry(runId) { + const runList = document.getElementById("run-list"); + const template = document.getElementById("run-list-entry-template"); + const runListEntry = template.content + .cloneNode(true) + .querySelector(".run-list-entry"); + runListEntry.id = `run-list-entry-${runId}`; + const a = runListEntry.querySelector("a"); + a.href = "#" + runId; + a.addEventListener("click", () => { + selectRun(runId); + }); + runList.insertBefore(runListEntry, runList.firstChild); + return runListEntry; + } + + function handleRunMessage(run) { + runs[run.id] = run; + let li = document.getElementById(`run-list-entry-${run.id}`); + if (!li) { + li = createRunListEntry(run.id); + } + + li.querySelector(".run-id").textContent = `Run ${run.id}`; + li.querySelector(".run-model").tExtContent = run.model; + li.querySelector(".run-tags").textContent = run.tag; + li.querySelector(".run-started-at").textContent = run.started_at.slice( + 0, + -3, + ); + if (run.stopped_at) { + li.querySelector(".run-stopped-at").textContent = run.stopped_at.slice( + 0, + -3, + ); + } + li.querySelector(".run-state").textContent = run.state; + + const followNewRunsCheckbox = document.getElementById("follow_new_runs"); + if 
(followNewRunsCheckbox.checked) { + selectRun(run.id); + } + } + + function addSectionDiv(sectionId) { + const messagesDiv = document.getElementById("messages"); + const template = document.getElementById("section-template"); + const sectionDiv = template.content + .cloneNode(true) + .querySelector(".section"); + sectionDiv.id = `section-${sectionId}`; + messagesDiv.appendChild(sectionDiv); + return sectionDiv; + } + + let sectionColumns = []; + + function handleSectionMessage(section) { + console.log("handling section message", section); + section.from_message += 1; + if (section.to_message === null) { + section.to_message = 99999; + } + section.to_message += 1; + + let sectionDiv = document.getElementById(`section-${section.id}`); + if (!!sectionDiv) { + let columnNumber = sectionDiv.getAttribute("columnNumber"); + let columnPosition = sectionDiv.getAttribute("columnPosition"); + sectionColumns[columnNumber].splice(columnPosition - 1, 1); + sectionDiv.remove(); + } + sectionDiv = addSectionDiv(section.id); + sectionDiv.querySelector(".section-name").textContent = + `${section.name} ${section.duration.toFixed(3)}s`; + + let columnNumber = 0; + let columnPosition = 0; + + // loop over the existing section Columns (format is a list of lists, whereby the inner list is [from_message, to_message], with to_message possibly being None) + let found = false; + for (let i = 0; i < sectionColumns.length; i++) { + const column = sectionColumns[i]; + let columnFits = true; + for (let j = 0; j < column.length; j++) { + const [from_message, to_message] = column[j]; + if ( + section.from_message < to_message && + from_message < section.to_message + ) { + columnFits = false; + break; + } + } + if (!columnFits) { + continue; + } + + column.push([section.from_message, section.to_message]); + columnNumber = i; + columnPosition = column.length; + found = true; + break; + } + if (!found) { + sectionColumns.push([[section.from_message, section.to_message]]); + 
document.documentElement.style.setProperty( + "--section-column-count", + sectionColumns.length, + ); + console.log( + "added section column", + sectionColumns.length, + sectionColumns, + ); + } + + sectionDiv.style = `grid-column: ${columnNumber}; grid-row: ${section.from_message} / ${section.to_message};`; + sectionDiv.setAttribute("columnNumber", columnNumber); + sectionDiv.setAttribute("columnPosition", columnPosition); + } + + function addMessageDiv(messageId, role) { + const messagesDiv = document.getElementById("messages"); + const template = document.getElementById("message-template"); + const messageDiv = template.content + .cloneNode(true) + .querySelector(".message"); + messageDiv.id = `message-${messageId}`; + messageDiv.style = `grid-row: ${messageId + 1};`; + if (role === "system") { + messageDiv.removeAttribute("open"); + } + messageDiv.querySelector(".tool-calls").id = + `message-${messageId}-tool-calls`; + messagesDiv.appendChild(messageDiv); + return messageDiv; + } + + function handleMessage(message) { + let messageDiv = document.getElementById(`message-${message.id}`); + if (!messageDiv) { + messageDiv = addMessageDiv(message.id, message.role); + } + if (message.content && message.content.length > 0) { + messageDiv.getElementsByTagName("pre")[0].textContent = message.content; + } + messageDiv.querySelector(".role").textContent = message.role; + messageDiv.querySelector(".duration").textContent = + `${message.duration.toFixed(3)} s`; + messageDiv.querySelector(".tokens-query").textContent = + `${message.tokens_query} qry tokens`; + messageDiv.querySelector(".tokens-response").textContent = + `${message.tokens_response} rsp tokens`; + } + + function handleMessageStreamPart(part) { + let messageDiv = document.getElementById(`message-${part.message_id}`); + if (!messageDiv) { + messageDiv = addMessageDiv(part.message_id); + } + messageDiv.getElementsByTagName("pre")[0].textContent += part.content; + } + + function addToolCallDiv(messageId, 
toolCallId, functionName) { + const toolCallsDiv = document.getElementById( + `message-${messageId}-tool-calls`, + ); + const template = document.getElementById("message-tool-call"); + const toolCallDiv = template.content + .cloneNode(true) + .querySelector(".tool-call"); + toolCallDiv.id = `message-${messageId}-tool-call-${toolCallId}`; + toolCallDiv.querySelector(".tool-call-function").textContent = + functionName; + toolCallsDiv.appendChild(toolCallDiv); + return toolCallDiv; + } + + function handleToolCall(toolCall) { + let toolCallDiv = document.getElementById( + `message-${toolCall.message_id}-tool-call-${toolCall.id}`, + ); + if (!toolCallDiv) { + toolCallDiv = addToolCallDiv( + toolCall.message_id, + toolCall.id, + toolCall.function_name, + ); + } + toolCallDiv.querySelector(".tool-call-state").textContent = + toolCall.state; + toolCallDiv.querySelector(".tool-call-duration").textContent = + `${toolCall.duration.toFixed(3)} s`; + toolCallDiv.querySelector(".tool-call-parameters").textContent = + toolCall.arguments; + toolCallDiv.querySelector(".tool-call-results").textContent = + toolCall.result_text; + } + + function handleToolCallStreamPart(part) { + const messageDiv = document.getElementById( + `message-${part.message_id}-tool-calls`, + ); + if (messageDiv) { + let toolCallDiv = messageDiv.querySelector( + `.tool-call-${part.tool_call_id}`, + ); + if (!toolCallDiv) { + toolCallDiv = document.createElement("div"); + toolCallDiv.className = `tool-call tool-call-${part.tool_call_id}`; + messageDiv.appendChild(toolCallDiv); + } + toolCallDiv.textContent += part.content; + } + } + + const selectRun = debounce((runId) => { + console.error("selectRun", runId, currentRun); + if (runId === currentRun) { + return; + } + + document.getElementById("messages").innerHTML = ""; + sectionColumns = []; + document.documentElement.style.setProperty("--section-column-count", 0); + send("MessageRequest", {follow_run: runId}); + currentRun = runId; + // set hash to runId via 
pushState + window.location.hash = runId; + sidebar.classList.remove("active"); + document.getElementById("main-run-title").textContent = `Run ${runId}`; + + // try to json parse and pretty print the run configuration into `#run-config` + try { + const config = JSON.parse(runs[runId].configuration); + document.getElementById("run-config").textContent = JSON.stringify( + config, + null, + 2, + ); + } catch (e) { + document.getElementById("run-config").textContent = + runs[runId].configuration; + } + }); + if (window.location.hash) { + selectRun(parseInt(window.location.hash.slice(1), 10)); + } else { + // toggle the sidebar if no run is selected + sidebar.classList.add("active"); + document.getElementById("main-run-title").textContent = + "Please select a run"; + } + + ws.addEventListener("close", initWebsocket); + }); + } + + initWebsocket(); +})(); \ No newline at end of file diff --git a/src/hackingBuddyGPT/resources/webui/static/favicon.ico b/src/hackingBuddyGPT/resources/webui/static/favicon.ico new file mode 100644 index 0000000..474dae3 Binary files /dev/null and b/src/hackingBuddyGPT/resources/webui/static/favicon.ico differ diff --git a/src/hackingBuddyGPT/resources/webui/static/style.css b/src/hackingBuddyGPT/resources/webui/static/style.css new file mode 100644 index 0000000..de021c0 --- /dev/null +++ b/src/hackingBuddyGPT/resources/webui/static/style.css @@ -0,0 +1,365 @@ +/* Reset default margin and padding */ +:root { + --section-count: 0; + --section-column-count: 0; +} + +* { + margin: 0; + padding: 0; + box-sizing: border-box; +} + +body { + font-family: Arial, sans-serif; +} + +pre { + white-space: pre-wrap; +} + +pre.binary { + white-space: break-spaces; + word-break: break-all; + word-wrap: anywhere; + overflow-wrap: anywhere; + -webkit-hyphens: auto; + hyphens: auto; + -webkit-line-break: after-white-space; +} + +details summary { + list-style: none; + cursor: pointer; +} +details summary::-webkit-details-marker { + display: none; +} + 
+.container { + display: grid; + grid-template-columns: 250px 1fr; + height: 100vh; + overflow: hidden; +} + +/* Sidebar styling */ +.sidebar { + background-color: #333; + color: white; + padding: 0 1rem 1rem; + height: 100%; + overflow: scroll; + z-index: 100; +} + +.sidebar ul { + list-style: none; + padding: 0; +} + +.sidebar li { + margin-bottom: 1rem; +} + +.sidebar a { + color: white; + text-decoration: none; +} + +.sidebar a:hover { + text-decoration: underline; +} + +.sidebar #run-list { + margin-top: 6.5rem; + padding-top: 1rem; +} + +.sidebar .run-list-entry a { + display: flex; + flex-direction: row; + justify-content: space-between; + align-items: center; + width: 100%; +} + +.sidebar .run-list-entry a > div { + display: flex; + flex-direction: column; +} + +.sidebar .run-list-info { + flex-grow: 1; +} + +.sidebar .run-list-info span { + color: lightgray; + font-size: small; +} + +.sidebar .run-list-timing { + flex-shrink: 0; + font-size: small; + color: lightgray; +} + +#follow-new-runs-container { + margin: 1.5rem 1rem 1rem; +} + +/* Main content styling */ +#main-body { + background-color: #f4f4f4; + height: 100%; + overflow: auto; +} + +#sidebar-header-container { + margin-left: -1rem; + height: 6.5rem; + display: flex; + flex-direction: column; + justify-content: start; + position: fixed; + background-color: #333; +} + +#sidebar-header, +#run-header { + display: flex; + flex-direction: row; + height: 3rem; + align-items: center; +} + +#run-header { + position: fixed; + background-color: #f4f4f4; + z-index: 50; + width: 100%; + border-top: 4px solid #333; + border-bottom: 4px solid #333; +} + +#black-block { + position: fixed; + height: 6.5rem; + width: calc(2rem + var(--section-column-count) * 1rem); + background-color: #333; + z-index: 25; +} + +#run-header .menu-toggle { + background-color: #333; + color: #333; + width: 6rem; + height: 3rem; +} + +#run-header #main-run-title { + display: inline-block; + flex-grow: 1; +} + +#sidebar-header 
.menu-toggle { + background-color: #333; + color: #f4f4f4; + width: 3rem; + height: 3rem; +} +.menu-toggle { + background: none; + border: none; + font-size: 24px; + line-height: 22px; + margin-right: 0.5rem; + color: white; +} + +.small { + font-size: small; +} + +#run-config-details { + padding-top: 3rem; + border-left: calc(2rem + var(--section-column-count) * 1rem) solid #333; +} + +#run-config-details summary { + /*background-color: #333; + color: white;*/ + padding: 0.3rem 0.3rem 0.3rem 1rem; + height: 3.5rem; + display: flex; + align-items: center; +} + +#run-config-details pre { + margin: 0 1rem; + padding-bottom: 1rem; +} + +#messages { + margin: 0 1rem 1rem; + display: grid; + /* this 1000 is a little bit of a hack, as other methods for auto sizing don't seem to work. Keep this one less than the number used as grid-column in .message */ + grid-template-columns: repeat(1000, min-content) 1fr; + grid-auto-rows: auto; + grid-gap: 0; +} + +.section { + display: flex; + flex-direction: column; + align-items: center; + position: relative; + width: 1rem; + justify-self: center; +} + +.section .line { + width: 4px; + background: black; + min-height: 0.2rem; + flex-grow: 1; +} + +.section .end-line { + margin-bottom: 1rem; +} + +.section span { + transform: rotate(-90deg); + padding: 0 4px; + margin: 5px 0; + white-space: nowrap; + background-color: #f4f4f4; +} + +.message { + /* this 1000 is a little bit of a hack, as other methods for auto sizing don't seem to work. 
Keep this one more than the number used in grid-template-columns in .messages */ + grid-column: calc(1001); + margin-left: 1rem; + margin-bottom: 1rem; + background-color: #f9f9f9; + border-left: 4px solid #333; +} + +/* this applies to both the message header as well as the individual tool calls */ +.message header { + background-color: #333; + color: white; + padding: 0.5rem; + display: flex; +} + +.message .tool-call header { + flex-direction: row; + justify-content: space-between; +} + +.message .message-header { + flex-direction: column; +} +.message .message-header > div { + display: flex; + flex-direction: row; + justify-content: space-between; +} + +.message .message-text { + margin: 1rem; +} + +.message .tool-calls { + margin: 1rem; + display: flex; + flex-direction: row; + flex-wrap: wrap; + gap: 1rem; +} + +.message .tool-call { + border: 2px solid #333; + border-radius: 4px; + padding-top: 0; + height: 100%; + width: 100%; +} + +.message .tool-call-parameters { + border-left: 4px solid lightgreen; + padding: 1rem 0.5rem; +} + +.message .tool-call-results { + border-left: 4px solid lightcoral; + padding: 1rem 0.5rem; +} + +/* Responsive behavior */ +@media (max-width: 1468px) { + .container { + grid-template-columns: 1fr; + } + + .sidebar { + position: absolute; + width: 100vw; + height: 100%; + top: 0; + left: -100vw; /* Hidden off-screen by default */ + transition: left 0.3s ease; + } + + #main-body { + grid-column: span 2; + } + + #sidebar-header .menu-toggle, + #run-header .menu-toggle { + display: inline-block; + cursor: pointer; + } + + /* Show the sidebar when toggled */ + .sidebar.active { + left: 0; + } + + #messages, + .message { + margin-left: 0.5rem; + margin-right: 0; + } + #run-header .menu-toggle { + width: 4rem; + color: white; + } + #run-config-details { + border-left: calc(1rem + var(--section-column-count) * 1rem) solid #333; + } + #black-block { + width: calc(1rem + var(--section-column-count) * 1rem); + } + + 
#sidebar-header-container { + width: 100%; + } + #sidebar-header .menu-toggle { + color: black; + background-color: #f4f4f4; + } + #sidebar-header { + border-top: 4px solid #f4f4f4; + border-bottom: 4px solid #f4f4f4; + width: 100%; + } + .sidebar #run-list { + margin-left: 2.5rem; + } + #follow-new-runs-container { + margin-left: 3.5rem; + } +} diff --git a/src/hackingBuddyGPT/resources/webui/templates/index.html b/src/hackingBuddyGPT/resources/webui/templates/index.html new file mode 100644 index 0000000..6a8475d --- /dev/null +++ b/src/hackingBuddyGPT/resources/webui/templates/index.html @@ -0,0 +1,96 @@ + + + + + + + hackingBuddyGPT + + +
+ + +
+
+ +

+
+
+
+ +

Configuration

+
+

+                
+
+
+
+ + + + + + + diff --git a/src/hackingBuddyGPT/usecases/__init__.py b/src/hackingBuddyGPT/usecases/__init__.py index a3a34c6..cc294ff 100644 --- a/src/hackingBuddyGPT/usecases/__init__.py +++ b/src/hackingBuddyGPT/usecases/__init__.py @@ -2,3 +2,4 @@ from .privesc import * from .web import * from .web_api_testing import * +from .viewer import * diff --git a/src/hackingBuddyGPT/usecases/agents.py b/src/hackingBuddyGPT/usecases/agents.py index 7497443..053b522 100644 --- a/src/hackingBuddyGPT/usecases/agents.py +++ b/src/hackingBuddyGPT/usecases/agents.py @@ -1,24 +1,24 @@ +import datetime from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Dict - from mako.template import Template -from rich.panel import Panel +from typing import Dict +from hackingBuddyGPT.utils.logging import log_conversation, GlobalLogger from hackingBuddyGPT.capabilities.capability import ( Capability, capabilities_to_simple_text_handler, ) -from hackingBuddyGPT.usecases.base import Logger from hackingBuddyGPT.utils import llm_util from hackingBuddyGPT.utils.openai.openai_llm import OpenAIConnection @dataclass class Agent(ABC): + log: GlobalLogger = None + _capabilities: Dict[str, Capability] = field(default_factory=dict) _default_capability: Capability = None - _log: Logger = None llm: OpenAIConnection = None @@ -36,14 +36,49 @@ def after_run(self): # noqa: B027 def perform_round(self, turn: int) -> bool: pass - def add_capability(self, cap: Capability, default: bool = False): - self._capabilities[cap.get_name()] = cap + def add_capability(self, cap: Capability, name: str = None, default: bool = False): + if name is None: + name = cap.get_name() + self._capabilities[name] = cap if default: self._default_capability = cap def get_capability(self, name: str) -> Capability: return self._capabilities.get(name, self._default_capability) + def run_capability_json(self, message_id: int, tool_call_id: str, capability_name: str, arguments: str) -> str: + 
capability = self.get_capability(capability_name) + + tic = datetime.datetime.now() + try: + result = capability.to_model().model_validate_json(arguments).execute() + except Exception as e: + result = f"EXCEPTION: {e}" + duration = datetime.datetime.now() - tic + + self.log.add_tool_call(message_id, tool_call_id, capability_name, arguments, result, duration) + return result + + def run_capability_simple_text(self, message_id: int, cmd: str) -> tuple[str, str, str, bool]: + _capability_descriptions, parser = capabilities_to_simple_text_handler(self._capabilities, default_capability=self._default_capability) + + tic = datetime.datetime.now() + try: + success, output = parser(cmd) + except Exception as e: + success = False + output = f"EXCEPTION: {e}" + duration = datetime.datetime.now() - tic + + if not success: + self.log.add_tool_call(message_id, tool_call_id=0, function_name="", arguments=cmd, result_text=output[0], duration=0) + return "", "", output, False + + capability, cmd, (result, got_root) = output + self.log.add_tool_call(message_id, tool_call_id=0, function_name=capability, arguments=cmd, result_text=result, duration=duration) + + return capability, cmd, result, got_root + def get_capability_block(self) -> str: capability_descriptions, _parser = capabilities_to_simple_text_handler(self._capabilities) return "You can either\n\n" + "\n".join(f"- {description}" for description in capability_descriptions.values()) @@ -75,28 +110,15 @@ def set_template(self, template: str): self._template = Template(filename=template) self._template_size = self.llm.count_tokens(self._template.source) + @log_conversation("Asking LLM for a new command...") def perform_round(self, turn: int) -> bool: - got_root: bool = False - - with self._log.console.status("[bold green]Asking LLM for a new command..."): - # TODO output/log state - options = self._state.to_template() - options.update({"capabilities": self.get_capability_block()}) - - # get the next command from the LLM - answer 
= self.llm.get_response(self._template, **options) - cmd = llm_util.cmd_output_fixer(answer.result) + # get the next command from the LLM + answer = self.llm.get_response(self._template, capabilities=self.get_capability_block(), **self._state.to_template()) + message_id = self.log.call_response(answer) - with self._log.console.status("[bold green]Executing that command..."): - self._log.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:")) - capability = self.get_capability(cmd.split(" ", 1)[0]) - result, got_root = capability(cmd) + capability, cmd, result, got_root = self.run_capability_simple_text(message_id, llm_util.cmd_output_fixer(answer.result)) - # log and output the command and its result - self._log.log_db.add_log_query(self._log.run_id, turn, cmd, result, answer) self._state.update(capability, cmd, result) - # TODO output/log new state - self._log.console.print(Panel(result, title=f"[bold cyan]{cmd}")) # if we got root, we can stop the loop return got_root diff --git a/src/hackingBuddyGPT/usecases/base.py b/src/hackingBuddyGPT/usecases/base.py index 10cd3bf..21f9182 100644 --- a/src/hackingBuddyGPT/usecases/base.py +++ b/src/hackingBuddyGPT/usecases/base.py @@ -1,28 +1,13 @@ import abc +import json import argparse -import typing from dataclasses import dataclass -from typing import Dict, Type -from rich.panel import Panel +from hackingBuddyGPT.utils.logging import GlobalLogger +from typing import Dict, Type, TypeVar, Generic -from hackingBuddyGPT.utils.configurable import ( - ParameterDefinitions, - build_parser, - get_arguments, - get_class_parameters, - transparent, -) -from hackingBuddyGPT.utils.console.console import Console -from hackingBuddyGPT.utils.db_storage.db_storage import DbStorage - - -@dataclass -class Logger: - log_db: DbStorage - console: Console - tag: str = "" - run_id: int = 0 +from hackingBuddyGPT.utils.configurable import ParameterDefinitions, build_parser, get_arguments, get_class_parameters, \ + 
Transparent, ParserState @dataclass @@ -37,21 +22,19 @@ class UseCase(abc.ABC): so that they can be automatically discovered and run from the command line. """ - log_db: DbStorage - console: Console - tag: str = "" + log: GlobalLogger - _run_id: int = 0 - _log: Logger = None - - def init(self): + def init(self, configuration): """ The init method is called before the run method. It is used to initialize the UseCase, and can be used to perform any dynamic setup that is needed before the run method is called. One of the most common use cases is setting up the llm capabilities from the tools that were injected. """ - self._run_id = self.log_db.create_new_run(self.get_name(), self.tag) - self._log = Logger(self.log_db, self.console, self.tag, self._run_id) + self.configuration = configuration + self.log.start_run(self.get_name(), self.serialize_configuration(configuration)) + + def serialize_configuration(self, configuration) -> str: + return json.dumps(configuration) @abc.abstractmethod def run(self): @@ -91,26 +74,28 @@ def run(self): self.before_run() turn = 1 - while turn <= self.max_turns and not self._got_root: - self._log.console.log(f"[yellow]Starting turn {turn} of {self.max_turns}") + try: + while turn <= self.max_turns and not self._got_root: + with self.log.section(f"round {turn}"): + self.log.console.log(f"[yellow]Starting turn {turn} of {self.max_turns}") - self._got_root = self.perform_round(turn) + self._got_root = self.perform_round(turn) - # finish turn and commit logs to storage - self._log.log_db.commit() - turn += 1 + turn += 1 - self.after_run() + self.after_run() - # write the final result to the database and console - if self._got_root: - self._log.log_db.run_was_success(self._run_id, turn) - self._log.console.print(Panel("[bold green]Got Root!", title="Run finished")) - else: - self._log.log_db.run_was_failure(self._run_id, turn) - self._log.console.print(Panel("[green]maximum turn number reached", title="Run finished")) + # write the final 
result to the database and console + if self._got_root: + self.log.run_was_success() + else: + self.log.run_was_failure("maximum turn number reached") - return self._got_root + return self._got_root + except Exception: + import traceback + self.log.run_was_failure("exception occurred", details=f":\n\n{traceback.format_exc()}") + raise @dataclass @@ -126,20 +111,21 @@ class _WrappedUseCase: parameters: ParameterDefinitions def build_parser(self, parser: argparse.ArgumentParser): - build_parser(self.parameters, parser) - parser.set_defaults(use_case=self) + parser_state = ParserState() + build_parser(self.parameters, parser, parser_state) + parser.set_defaults(use_case=self, parser_state=parser_state) def __call__(self, args: argparse.Namespace): - return self.use_case(**get_arguments(self.parameters, args)) + return self.use_case(**get_arguments(self.parameters, args, args.parser_state)) use_cases: Dict[str, _WrappedUseCase] = dict() -T = typing.TypeVar("T") +T = TypeVar("T") -class AutonomousAgentUseCase(AutonomousUseCase, typing.Generic[T]): +class AutonomousAgentUseCase(AutonomousUseCase, Generic[T]): agent: T = None def perform_round(self, turn: int): @@ -154,11 +140,10 @@ def __class_getitem__(cls, item): item.__parameters__ = get_class_parameters(item) class AutonomousAgentUseCase(AutonomousUseCase): - agent: transparent(item) = None + agent: Transparent(item) = None - def init(self): - super().init() - self.agent._log = self._log + def init(self, configuration): + super().init(configuration) self.agent.init() def get_name(self) -> str: diff --git a/src/hackingBuddyGPT/usecases/examples/agent.py b/src/hackingBuddyGPT/usecases/examples/agent.py index b87b540..337cf38 100644 --- a/src/hackingBuddyGPT/usecases/examples/agent.py +++ b/src/hackingBuddyGPT/usecases/examples/agent.py @@ -1,9 +1,9 @@ import pathlib from mako.template import Template -from rich.panel import Panel from hackingBuddyGPT.capabilities import SSHRunCommand, SSHTestCredential +from 
hackingBuddyGPT.utils.logging import log_conversation from hackingBuddyGPT.usecases.agents import Agent from hackingBuddyGPT.usecases.base import AutonomousAgentUseCase, use_case from hackingBuddyGPT.utils import SSHConnection, llm_util @@ -15,40 +15,35 @@ class ExPrivEscLinux(Agent): conn: SSHConnection = None + _sliding_history: SlidingCliHistory = None + _max_history_size: int = 0 def init(self): super().init() + self._sliding_history = SlidingCliHistory(self.llm) + self._max_history_size = self.llm.context_size - llm_util.SAFETY_MARGIN - self.llm.count_tokens(template_next_cmd.source) + self.add_capability(SSHRunCommand(conn=self.conn), default=True) self.add_capability(SSHTestCredential(conn=self.conn)) - self._template_size = self.llm.count_tokens(template_next_cmd.source) + @log_conversation("Asking LLM for a new command...") def perform_round(self, turn: int) -> bool: - got_root: bool = False - - with self._log.console.status("[bold green]Asking LLM for a new command..."): - # get as much history as fits into the target context size - history = self._sliding_history.get_history( - self.llm.context_size - llm_util.SAFETY_MARGIN - self._template_size - ) - - # get the next command from the LLM - answer = self.llm.get_response( - template_next_cmd, capabilities=self.get_capability_block(), history=history, conn=self.conn - ) - cmd = llm_util.cmd_output_fixer(answer.result) - - with self._log.console.status("[bold green]Executing that command..."): - self._log.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:")) - result, got_root = self.get_capability(cmd.split(" ", 1)[0])(cmd) - - # log and output the command and its result - self._log.log_db.add_log_query(self._log.run_id, turn, cmd, result, answer) + # get as much history as fits into the target context size + history = self._sliding_history.get_history(self._max_history_size) + + # get the next command from the LLM + answer = self.llm.get_response(template_next_cmd, 
capabilities=self.get_capability_block(), history=history, conn=self.conn) + message_id = self.log.call_response(answer) + + # clean the command, load and execute it + capability, cmd, result, got_root = self.run_capability_simple_text(message_id, llm_util.cmd_output_fixer(answer.result)) + + # store the results in our local history self._sliding_history.add_command(cmd, result) - self._log.console.print(Panel(result, title=f"[bold cyan]{cmd}")) - # if we got root, we can stop the loop + # signal if we were successful in our task return got_root diff --git a/src/hackingBuddyGPT/usecases/examples/hintfile.py b/src/hackingBuddyGPT/usecases/examples/hintfile.py index 274b4cd..b5b1670 100644 --- a/src/hackingBuddyGPT/usecases/examples/hintfile.py +++ b/src/hackingBuddyGPT/usecases/examples/hintfile.py @@ -8,8 +8,8 @@ class ExPrivEscLinuxHintFileUseCase(AutonomousAgentUseCase[LinuxPrivesc]): hints: str = None - def init(self): - super().init() + def init(self, configuration): + super().init(configuration) self.agent.hint = self.read_hint() # simple helper that reads the hints file and returns the hint @@ -21,7 +21,7 @@ def read_hint(self): if self.agent.conn.hostname in hints: return hints[self.agent.conn.hostname] except FileNotFoundError: - self._log.console.print("[yellow]Hint file not found") + self.log.console.print("[yellow]Hint file not found") except Exception as e: - self._log.console.print("[yellow]Hint file could not loaded:", str(e)) + self.log.console.print("[yellow]Hint file could not loaded:", str(e)) return "" diff --git a/src/hackingBuddyGPT/usecases/examples/lse.py b/src/hackingBuddyGPT/usecases/examples/lse.py index 3e31cd7..cdf135c 100644 --- a/src/hackingBuddyGPT/usecases/examples/lse.py +++ b/src/hackingBuddyGPT/usecases/examples/lse.py @@ -26,20 +26,17 @@ class ExPrivEscLinuxLSEUseCase(UseCase): # use either an use-case or an agent to perform the privesc use_use_case: bool = False - def init(self): - super().init() - # simple helper that uses 
lse.sh to get hints from the system def call_lse_against_host(self): - self._log.console.print("[green]performing initial enumeration with lse.sh") + self.log.console.print("[green]performing initial enumeration with lse.sh") run_cmd = "wget -q 'https://github.com/diego-treitos/linux-smart-enumeration/releases/latest/download/lse.sh' -O lse.sh;chmod 700 lse.sh; ./lse.sh -c -i -l 0 | grep -v 'nope$' | grep -v 'skip$'" result, _ = SSHRunCommand(conn=self.conn, timeout=120)(run_cmd) - self.console.print("[yellow]got the output: " + result) + self.log.console.print("[yellow]got the output: " + result) cmd = self.llm.get_response(template_lse, lse_output=result, number=3) - self.console.print("[yellow]got the cmd: " + cmd.result) + self.log.console.print("[yellow]got the cmd: " + cmd.result) return [x for x in cmd.result.splitlines() if x.strip()] @@ -54,14 +51,14 @@ def run(self): # now try to escalate privileges using the hints for hint in hints: if self.use_use_case: - self.console.print("[yellow]Calling a use-case to perform the privilege escalation") + self.log.console.print("[yellow]Calling a use-case to perform the privilege escalation") result = self.run_using_usecases(hint, turns_per_hint) else: - self.console.print("[yellow]Calling an agent to perform the privilege escalation") + self.log.console.print("[yellow]Calling an agent to perform the privilege escalation") result = self.run_using_agent(hint, turns_per_hint) if result is True: - self.console.print("[green]Got root!") + self.log.console.print("[green]Got root!") return True def run_using_usecases(self, hint, turns_per_hint): @@ -76,10 +73,9 @@ def run_using_usecases(self, hint, turns_per_hint): hint=hint, ), max_turns=turns_per_hint, - log_db=self.log_db, - console=self.console, + log=self.log, ) - linux_privesc.init() + linux_privesc.init(self.configuration) return linux_privesc.run() def run_using_agent(self, hint, turns_per_hint): @@ -92,7 +88,7 @@ def run_using_agent(self, hint, turns_per_hint): 
enable_update_state=self.enable_update_state, disable_history=self.disable_history, ) - agent._log = self._log + agent.log = self.log agent.init() # perform the privilege escalation @@ -100,7 +96,7 @@ def run_using_agent(self, hint, turns_per_hint): turn = 1 got_root = False while turn <= turns_per_hint and not got_root: - self._log.console.log(f"[yellow]Starting turn {turn} of {turns_per_hint}") + self.log.console.log(f"[yellow]Starting turn {turn} of {turns_per_hint}") if agent.perform_round(turn) is True: got_root = True diff --git a/src/hackingBuddyGPT/usecases/privesc/common.py b/src/hackingBuddyGPT/usecases/privesc/common.py index 5bf8003..b528565 100644 --- a/src/hackingBuddyGPT/usecases/privesc/common.py +++ b/src/hackingBuddyGPT/usecases/privesc/common.py @@ -1,14 +1,14 @@ +import datetime import pathlib from dataclasses import dataclass, field -from typing import Any, Dict - from mako.template import Template -from rich.panel import Panel +from typing import Any, Dict, Optional from hackingBuddyGPT.capabilities import Capability from hackingBuddyGPT.capabilities.capability import capabilities_to_simple_text_handler from hackingBuddyGPT.usecases.agents import Agent -from hackingBuddyGPT.utils import llm_util, ui +from hackingBuddyGPT.utils.logging import log_section, log_conversation +from hackingBuddyGPT.utils import llm_util from hackingBuddyGPT.utils.cli_history import SlidingCliHistory template_dir = pathlib.Path(__file__).parent / "templates" @@ -31,12 +31,9 @@ class Privesc(Agent): _template_params: Dict[str, Any] = field(default_factory=dict) _max_history_size: int = 0 - def init(self): - super().init() - def before_run(self): if self.hint != "": - self._log.console.print(f"[bold green]Using the following hint: '{self.hint}'") + self.log.status_message(f"[bold green]Using the following hint: '{self.hint}'") if self.disable_history is False: self._sliding_history = SlidingCliHistory(self.llm) @@ -54,56 +51,24 @@ def before_run(self): 
self._max_history_size = self.llm.context_size - llm_util.SAFETY_MARGIN - template_size def perform_round(self, turn: int) -> bool: - got_root: bool = False - - with self._log.console.status("[bold green]Asking LLM for a new command..."): - answer = self.get_next_command() - cmd = answer.result - - with self._log.console.status("[bold green]Executing that command..."): - self._log.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:")) - _capability_descriptions, parser = capabilities_to_simple_text_handler( - self._capabilities, default_capability=self._default_capability - ) - success, *output = parser(cmd) - if not success: - self._log.console.print(Panel(output[0], title="[bold red]Error parsing command:")) - return False - - assert len(output) == 1 - capability, cmd, (result, got_root) = output[0] + # get the next command and run it + cmd, message_id = self.get_next_command() + result, got_root = self.run_command(cmd, message_id) # log and output the command and its result - self._log.log_db.add_log_query(self._log.run_id, turn, cmd, result, answer) if self._sliding_history: self._sliding_history.add_command(cmd, result) - self._log.console.print(Panel(result, title=f"[bold cyan]{cmd}")) - # analyze the result.. if self.enable_explanation: - with self._log.console.status("[bold green]Analyze its result..."): - answer = self.analyze_result(cmd, result) - self._log.log_db.add_log_analyze_response(self._log.run_id, turn, cmd, answer.result, answer) + self.analyze_result(cmd, result) # .. and let our local model update its state if self.enable_update_state: - # this must happen before the table output as we might include the - # status processing time in the table.. - with self._log.console.status("[bold green]Updating fact list.."): - state = self.update_state(cmd, result) - self._log.log_db.add_log_update_state(self._log.run_id, turn, "", state.result, state) - - # Output Round Data.. 
- self._log.console.print( - ui.get_history_table( - self.enable_explanation, self.enable_update_state, self._log.run_id, self._log.log_db, turn - ) - ) - - # .. and output the updated state - if self.enable_update_state: - self._log.console.print(Panel(self._state, title="What does the LLM Know about the system?")) + self.update_state(cmd, result) + + # Output Round Data.. # TODO: reimplement + # self.log.console.print(ui.get_history_table(self.enable_explanation, self.enable_update_state, self.log.run_id, self.log.log_db, turn)) # if we got root, we can stop the loop return got_root @@ -114,7 +79,8 @@ def get_state_size(self) -> int: else: return 0 - def get_next_command(self) -> llm_util.LLMResult: + @log_conversation("Asking LLM for a new command...", start_section=True) + def get_next_command(self) -> tuple[str, int]: history = "" if not self.disable_history: history = self._sliding_history.get_history(self._max_history_size - self.get_state_size()) @@ -122,17 +88,37 @@ def get_next_command(self) -> llm_util.LLMResult: self._template_params.update({"history": history, "state": self._state}) cmd = self.llm.get_response(template_next_cmd, **self._template_params) - cmd.result = llm_util.cmd_output_fixer(cmd.result) - return cmd + message_id = self.log.call_response(cmd) + + return llm_util.cmd_output_fixer(cmd.result), message_id + + @log_section("Executing that command...") + def run_command(self, cmd, message_id) -> tuple[Optional[str], bool]: + _capability_descriptions, parser = capabilities_to_simple_text_handler(self._capabilities, default_capability=self._default_capability) + start_time = datetime.datetime.now() + success, *output = parser(cmd) + if not success: + self.log.add_tool_call(message_id, tool_call_id=0, function_name="", arguments=cmd, result_text=output[0], duration=0) + return output[0], False + + assert len(output) == 1 + capability, cmd, (result, got_root) = output[0] + duration = datetime.datetime.now() - start_time + 
self.log.add_tool_call(message_id, tool_call_id=0, function_name=capability, arguments=cmd, result_text=result, duration=duration) + + return result, got_root + @log_conversation("Analyze its result...", start_section=True) def analyze_result(self, cmd, result): state_size = self.get_state_size() target_size = self.llm.context_size - llm_util.SAFETY_MARGIN - state_size # ugly, but cut down result to fit context size result = llm_util.trim_result_front(self.llm, target_size, result) - return self.llm.get_response(template_analyze, cmd=cmd, resp=result, facts=self._state) + answer = self.llm.get_response(template_analyze, cmd=cmd, resp=result, facts=self._state) + self.log.call_response(answer) + @log_conversation("Updating fact list..", start_section=True) def update_state(self, cmd, result): # ugly, but cut down result to fit context size # don't do this linearly as this can take too long @@ -141,6 +127,6 @@ def update_state(self, cmd, result): target_size = ctx - llm_util.SAFETY_MARGIN - state_size result = llm_util.trim_result_front(self.llm, target_size, result) - result = self.llm.get_response(template_state, cmd=cmd, resp=result, facts=self._state) - self._state = result.result - return result + state = self.llm.get_response(template_state, cmd=cmd, resp=result, facts=self._state) + self._state = state.result + self.log.call_response(state) diff --git a/src/hackingBuddyGPT/usecases/viewer.py b/src/hackingBuddyGPT/usecases/viewer.py new file mode 100755 index 0000000..08097fb --- /dev/null +++ b/src/hackingBuddyGPT/usecases/viewer.py @@ -0,0 +1,397 @@ +#!/usr/bin/python3 + +import asyncio +import datetime +import json +import os +import random +import string +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from enum import Enum +import time +from typing import Optional, Union + +from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect +from fastapi.responses import FileResponse, HTMLResponse +from 
starlette.staticfiles import StaticFiles +from starlette.templating import Jinja2Templates + +from hackingBuddyGPT.usecases.base import UseCase, use_case +from hackingBuddyGPT.utils.db_storage import DbStorage +from hackingBuddyGPT.utils.db_storage.db_storage import ( + Message, + MessageStreamPart, + Run, + Section, + ToolCall, + ToolCallStreamPart, +) +from dataclasses_json import dataclass_json + +from hackingBuddyGPT.utils.logging import GlobalLocalLogger, GlobalRemoteLogger + +INGRESS_TOKEN = os.environ.get("INGRESS_TOKEN", None) +VIEWER_TOKEN = os.environ.get("VIEWER_TOKEN", random.choices(string.ascii_letters + string.digits, k=32)) + + +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) + "/" +RESOURCE_DIR = BASE_DIR + "../resources/webui" +TEMPLATE_DIR = RESOURCE_DIR + "/templates" +STATIC_DIR = RESOURCE_DIR + "/static" + + +@dataclass_json +@dataclass(frozen=True) +class MessageRequest: + follow_run: Optional[int] = None + + +MessageData = Union[MessageRequest, Run, Section, Message, MessageStreamPart, ToolCall, ToolCallStreamPart] + + +class MessageType(str, Enum): + MESSAGE_REQUEST = "MessageRequest" + RUN = "Run" + SECTION = "Section" + MESSAGE = "Message" + MESSAGE_STREAM_PART = "MessageStreamPart" + TOOL_CALL = "ToolCall" + TOOL_CALL_STREAM_PART = "ToolCallStreamPart" + + def get_class(self) -> MessageData: + return { + "MessageRequest": MessageRequest, + "Run": Run, + "Section": Section, + "Message": Message, + "MessageStreamPart": MessageStreamPart, + "ToolCall": ToolCall, + "ToolCallStreamPart": ToolCallStreamPart, + }[self.value] + + +@dataclass_json +@dataclass +class ControlMessage: + type: MessageType + data: MessageData + + +@dataclass_json +@dataclass(frozen=True) +class ReplayMessage: + at: datetime.datetime + message: ControlMessage + + +@dataclass +class Client: + websocket: WebSocket + db: DbStorage + + queue: asyncio.Queue[ControlMessage] = field(default_factory=asyncio.Queue) + + current_run = None + follow_new_runs = False + + 
async def send_message(self, message: ControlMessage) -> None: + await self.websocket.send_text(message.to_json()) + + async def send(self, type: MessageType, message: MessageData) -> None: + await self.send_message(ControlMessage(type, message)) + + async def send_messages(self) -> None: + runs = self.db.get_runs() + for r in runs: + await self.send(MessageType.RUN, r) + + while True: + try: + msg: ControlMessage = await self.queue.get() + data = msg.data + if msg.type == MessageType.MESSAGE_REQUEST: + if data.follow_run is not None: + await self.switch_to_run(data.follow_run) + + elif msg.type == MessageType.RUN: + await self.send_message(msg) + + elif msg.type in MessageType: + if not hasattr(data, "run_id"): + print("msg has no run_id", data) + if self.current_run == data.run_id: + await self.send_message(msg) + + else: + print(f"Unknown message type: {msg.type}") + + except WebSocketDisconnect: + break + + except Exception as e: + print(f"Error sending message: {e}") + raise e + + async def receive_messages(self) -> None: + while True: + try: + msg = await self.websocket.receive_json() + if msg["type"] != MessageType.MESSAGE_REQUEST: + print(f"Unknown message type: {msg['type']}") + continue + + if "data" not in msg: + print("Invalid message") + continue + + data = msg["data"] + + if "follow_run" not in data: + print("Invalid message") + continue + + message = ControlMessage( + type=MessageType.MESSAGE_REQUEST, + data=MessageRequest(int(data["follow_run"])), + ) + # we don't process the message here, as having all message processing done in lockstep in the send_messages + # function means that we don't have to worry about race conditions between reading from the database and + # incoming messages + await self.queue.put(message) + except Exception as e: + print(f"Error receiving message: {e}") + raise e + + async def switch_to_run(self, run_id: int): + self.current_run = run_id + messages = self.db.get_messages_by_run(run_id) + + tool_calls = 
list(self.db.get_tool_calls_by_run(run_id)) + tool_calls_per_message = dict() + for tc in tool_calls: + if tc.message_id not in tool_calls_per_message: + tool_calls_per_message[tc.message_id] = [] + tool_calls_per_message[tc.message_id].append(tc) + + sections: list[Section] = list(self.db.get_sections_by_run(run_id)) + sections_starting_with_message = dict() + for s in sections: + if s.from_message not in sections_starting_with_message: + sections_starting_with_message[s.from_message] = [] + sections_starting_with_message[s.from_message].append(s) + + for msg in messages: + if msg.id in sections_starting_with_message: + for s in sections_starting_with_message[msg.id]: + await self.send(MessageType.SECTION, s) + sections.remove(s) + await self.send(MessageType.MESSAGE, msg) + if msg.id in tool_calls_per_message: + for tc in tool_calls_per_message[msg.id]: + await self.send(MessageType.TOOL_CALL, tc) + tool_calls.remove(tc) + + for tc in tool_calls: + await self.send(MessageType.TOOL_CALL, tc) + + for s in sections: + await self.send(MessageType.SECTION, s) + + +@use_case("Webserver for (live) log viewing") +class Viewer(UseCase): + """ + TODOs: + - [ ] This server needs to be as async as possible to allow good performance, but the database accesses are not yet, might be an issue? 
+ """ + log: GlobalLocalLogger + log_db: DbStorage + listen_host: str = "127.0.0.1" + listen_port: int = 4444 + save_playback_dir: str = "" + + async def save_message(self, message: ControlMessage): + if not self.save_playback_dir or len(self.save_playback_dir) == 0: + return + + # check if a file with the name of the message run id already exists in the save_playback_dir + # if it does, append the message to the json lines file + # if it doesn't, create a new file with the name of the message run id and write the message to it + if isinstance(message.data, Run): + run_id = message.data.id + elif hasattr(message.data, "run_id"): + run_id = message.data.run_id + else: + raise ValueError("gotten message without run_id", message) + + if not os.path.exists(self.save_playback_dir): + os.makedirs(self.save_playback_dir) + + file_path = os.path.join(self.save_playback_dir, f"{run_id}.jsonl") + with open(file_path, "a") as f: + f.write(ReplayMessage(datetime.datetime.now(), message).to_json() + "\n") + + def run(self): + @asynccontextmanager + async def lifespan(app: FastAPI): + app.state.db = self.log_db + app.state.clients = [] + + yield + + for client in app.state.clients: + await client.websocket.close() + + app = FastAPI(lifespan=lifespan) + + # TODO: re-enable and only allow anything else than localhost when a token is set + """ + app.add_middleware( + CORSMiddleware, + allow_origins=["http://localhost:4444", "ws://localhost:4444", "wss://pwn.reinsperger.org", "https://pwn.reinsperger.org", "https://dumb-halloween-game.reinsperger.org"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + """ + + templates = Jinja2Templates(directory=TEMPLATE_DIR) + app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") + + @app.get('/favicon.ico') + async def favicon(): + return FileResponse(STATIC_DIR + "/favicon.ico", headers={"Cache-Control": "public, max-age=31536000"}) + + @app.get("/", response_class=HTMLResponse) + async def 
admin_ui(request: Request): + return templates.TemplateResponse("index.html", {"request": request}) + + @app.websocket("/ingress") + async def ingress_endpoint(websocket: WebSocket): + await websocket.accept() + try: + while True: + # Receive messages from the ingress websocket + data = await websocket.receive_json() + message_type = MessageType(data["type"]) + # parse the data according to the message type into the appropriate dataclass + message = message_type.get_class().from_dict(data["data"]) + + if message_type == MessageType.RUN: + if message.id is None: + message.started_at = datetime.datetime.now() + message.id = app.state.db.create_run(message.model, message.tag, message.started_at, message.configuration) + data["data"]["id"] = message.id # set the id also in the raw data, so we can properly serialize it to replays + else: + app.state.db.update_run(message.id, message.model, message.state, message.tag, message.started_at, message.stopped_at, message.configuration) + await websocket.send_text(message.to_json()) + + elif message_type == MessageType.MESSAGE: + app.state.db.add_or_update_message(message.run_id, message.id, message.conversation, message.role, message.content, message.tokens_query, message.tokens_response, message.duration) + + elif message_type == MessageType.MESSAGE_STREAM_PART: + app.state.db.handle_message_update(message.run_id, message.message_id, message.action, message.content) + + elif message_type == MessageType.TOOL_CALL: + app.state.db.add_tool_call(message.run_id, message.message_id, message.id, message.function_name, message.arguments, message.result_text, message.duration) + + elif message_type == MessageType.SECTION: + app.state.db.add_section(message.run_id, message.id, message.name, message.from_message, message.to_message, message.duration) + + else: + print("UNHANDLED ingress", message) + + control_message = ControlMessage(type=message_type, data=message) + await self.save_message(control_message) + for client in 
app.state.clients: + await client.queue.put(control_message) + + except WebSocketDisconnect as e: + import traceback + traceback.print_exc() + print("Ingress WebSocket disconnected") + + @app.websocket("/client") + async def client_endpoint(websocket: WebSocket): + await websocket.accept() + client = Client(websocket, app.state.db) + app.state.clients.append(client) + + # run the receiving and sending tasks in the background until one of them returns + tasks = () + try: + tasks = ( + asyncio.create_task(client.send_messages()), + asyncio.create_task(client.receive_messages()), + ) + await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + except WebSocketDisconnect: + # read the task exceptions, close remaining tasks + for task in tasks: + if task.exception(): + print(task.exception()) + else: + task.cancel() + app.state.clients.remove(client) + print("Egress WebSocket disconnected") + + import uvicorn + uvicorn.run(app, host=self.listen_host, port=self.listen_port) + + def get_name(self) -> str: + return "log_viewer" + + +@use_case("Tool to replay the .jsonl logs generated by the Viewer (not well tested)") +class Replayer(UseCase): + log: GlobalRemoteLogger + replay_file: str + pause_on_message: bool = False + pause_on_tool_calls: bool = False + playback_speed: float = 1.0 + + def get_name(self) -> str: + return "replayer" + + def init(self, configuration): + self.log.init_websocket() # we don't want to automatically start a run here + + def run(self): + recording_start: Optional[datetime.datetime] = None + replay_start: datetime.datetime = datetime.datetime.now() + + print(f"replaying {self.replay_file}") + for line in open(self.replay_file, "r"): + data = json.loads(line) + msg: ReplayMessage = ReplayMessage.from_dict(data) + msg.message.type = MessageType(data["message"]["type"]) + msg.message.data = msg.message.type.get_class().from_dict(data["message"]["data"]) + + if recording_start is None: + if msg.message.type != MessageType.RUN: + raise 
ValueError("First message must be a RUN message, is", msg.message.type) + recording_start = msg.at + self.log.start_run(msg.message.data.model, msg.message.data.tag, msg.message.data.configuration, msg.at) + + # wait until the message should be sent + sleep_time = ((msg.at - recording_start) / self.playback_speed) - (datetime.datetime.now() - replay_start) + if sleep_time.total_seconds() > 3: + print(msg) + print(f"sleeping for {sleep_time.total_seconds()}s") + time.sleep(max(sleep_time.total_seconds(), 0)) + + if isinstance(msg.message.data, Run): + msg.message.data.id = self.log.run.id + elif hasattr(msg.message.data, "run_id"): + msg.message.data.run_id = self.log.run.id + else: + raise ValueError("Message has no run_id", msg.message.data) + + if self.pause_on_message and msg.message.type == MessageType.MESSAGE \ + or self.pause_on_tool_calls and msg.message.type == MessageType.TOOL_CALL: + input("Paused, press Enter to continue") + replay_start = datetime.datetime.now() - (msg.at - recording_start) + + print("sending") + self.log.send(msg.message.type, msg.message.data) diff --git a/src/hackingBuddyGPT/usecases/web/__init__.py b/src/hackingBuddyGPT/usecases/web/__init__.py index 0d9307b..d09ebd9 100644 --- a/src/hackingBuddyGPT/usecases/web/__init__.py +++ b/src/hackingBuddyGPT/usecases/web/__init__.py @@ -1,2 +1,3 @@ -from .simple import MinimalWebTesting from .with_explanation import WebTestingWithExplanation + +__all__ = ['WebTestingWithExplanation'] diff --git a/src/hackingBuddyGPT/usecases/web/simple.py b/src/hackingBuddyGPT/usecases/web/simple.py deleted file mode 100644 index eb07556..0000000 --- a/src/hackingBuddyGPT/usecases/web/simple.py +++ /dev/null @@ -1,110 +0,0 @@ -import time -from dataclasses import field -from typing import Any, Dict, List, Union - -import pydantic_core -from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageParam -from rich.panel import Panel - -from hackingBuddyGPT.capabilities import Capability -from 
hackingBuddyGPT.capabilities.capability import capabilities_to_action_model -from hackingBuddyGPT.capabilities.http_request import HTTPRequest -from hackingBuddyGPT.capabilities.record_note import RecordNote -from hackingBuddyGPT.capabilities.submit_flag import SubmitFlag -from hackingBuddyGPT.usecases.agents import Agent -from hackingBuddyGPT.usecases.base import AutonomousAgentUseCase, use_case -from hackingBuddyGPT.utils import LLMResult, tool_message -from hackingBuddyGPT.utils.configurable import parameter -from hackingBuddyGPT.utils.openai.openai_lib import OpenAILib - -Prompt = List[Union[ChatCompletionMessage, ChatCompletionMessageParam]] -Context = Any - - -class MinimalWebTesting(Agent): - llm: OpenAILib - host: str = parameter(desc="The host to test", default="http://localhost") - flag_format_description: str = parameter( - desc="Description of the flag provided to the LLM", - default="a string starting with 'FLAG.' and ending with '.GALF'", - ) - flag_template: str = parameter( - desc="The template of the flags, whereby {flag} is replaced with the flags", - default="FLAG.{flag}.GALF", - ) - flags: str = parameter( - desc="A comma (,) separated list of flags to find", - default="hostname,dir,username,rootfile,secretfile,adminpass", - ) - - _prompt_history: Prompt = field(default_factory=list) - _context: Context = field(default_factory=lambda: {"notes": list()}) - _capabilities: Dict[str, Capability] = field(default_factory=dict) - _all_flags_found: bool = False - - def init(self): - super().init() - self._prompt_history.append( - { - "role": "system", - "content": f"You are a senior penetration tester, having received a new project of testing a website at {self.host}." 
- f"Your goal is first to explore the website, building up your knowledge about its parts and inner workings, and then try to plan and execute possible exploitations.\n" - f"Make sure to take lots of notes about the pages you visit, so that you can later more easily work through the page.\n" - f"Do not brute-force anything, all information that you might need can be found on the webserver.\n" - f"THE MOST IMPORTANT THING TO DO IS, that if you see a flag ({self.flag_format_description}), you should submit it immediately.", - } - ) - self._context["host"] = self.host - self._capabilities = { - "submit_flag": SubmitFlag( - self.flag_format_description, - set(self.flag_template.format(flag=flag) for flag in self.flags.split(",")), - success_function=self.all_flags_found, - ), - "http_request": HTTPRequest(self.host), - "record_note": RecordNote(self._context["notes"]), - } - - def all_flags_found(self): - self._log.console.print(Panel("All flags found! Congratulations!", title="system")) - self._all_flags_found = True - - def perform_round(self, turn: int): - with self._log.console.status("[bold green]Asking LLM for a new command..."): - prompt = self._prompt_history # TODO: in the future, this should do some context truncation - - tic = time.perf_counter() - response, completion = self.llm.instructor.chat.completions.create_with_completion( - model=self.llm.model, - messages=prompt, - response_model=capabilities_to_action_model(self._capabilities), - ) - toc = time.perf_counter() - - message = completion.choices[0].message - tool_call_id = message.tool_calls[0].id - command = pydantic_core.to_json(response).decode() - self._log.console.print(Panel(command, title="assistant")) - self._prompt_history.append(message) - - answer = LLMResult( - completion.choices[0].message.content, - str(prompt), - completion.choices[0].message.content, - toc - tic, - completion.usage.prompt_tokens, - completion.usage.completion_tokens, - ) - - with self._log.console.status("[bold 
green]Executing that command..."): - result = response.execute() - self._log.console.print(Panel(result, title="tool")) - self._prompt_history.append(tool_message(result, tool_call_id)) - - self._log.log_db.add_log_query(self._log.run_id, turn, command, result, answer) - return self._all_flags_found - - -@use_case("Minimal implementation of a web testing use case") -class MinimalWebTestingUseCase(AutonomousAgentUseCase[MinimalWebTesting]): - pass diff --git a/src/hackingBuddyGPT/usecases/web/with_explanation.py b/src/hackingBuddyGPT/usecases/web/with_explanation.py index 1559217..5ed7f44 100644 --- a/src/hackingBuddyGPT/usecases/web/with_explanation.py +++ b/src/hackingBuddyGPT/usecases/web/with_explanation.py @@ -1,9 +1,8 @@ -import time from dataclasses import field -from typing import Any, Dict, List, Union +from typing import List, Any, Union, Dict, Iterable, Optional -from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageParam -from rich.panel import Panel +from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage +from openai.types.chat.chat_completion_chunk import ChoiceDelta from hackingBuddyGPT.capabilities import Capability from hackingBuddyGPT.capabilities.http_request import HTTPRequest @@ -41,75 +40,51 @@ class WebTestingWithExplanation(Agent): def init(self): super().init() - self._prompt_history.append( - { - "role": "system", - "content": f"You are a senior penetration tester, having received a new project of testing a website at {self.host}." 
- f"Your goal is first to explore the website, building up your knowledge about its parts and inner workings, and then try to plan and execute possible exploitations.\n" - f"Think step by step and say everything out loud that you might be doing.\n" - f"Give a list of assumptions you base your tool calls on, before you start a tool call, and reason whether they held or not after the call has completed.\n" - f"Do not brute-force anything, all information that you might need can be found on the webserver.\n" - f"THE MOST IMPORTANT THING TO DO IS, that if you see a flag ({self.flag_format_description}), you should submit it immediately.", - } - ) self._context["host"] = self.host - self._capabilities = { - "submit_flag": SubmitFlag( - self.flag_format_description, - set(self.flag_template.format(flag=flag) for flag in self.flags.split(",")), - success_function=self.all_flags_found, - ), - "http_request": HTTPRequest(self.host), - } + self.add_capability(SubmitFlag(self.flag_format_description, set(self.flag_template.format(flag=flag) for flag in self.flags.split(",")), success_function=self.all_flags_found)) + self.add_capability(HTTPRequest(self.host)) + + def before_run(self): + system_message = ( + f"You are a senior penetration tester, having received a new project of testing a website at {self.host}." + f"Your goal is first to explore the website, building up your knowledge about its parts and inner workings, and then try to plan and execute possible exploitations.\n" + f"Start out by trying to get logged in as the admin of the site. DO NOT BRUTE-FORCE CREDENTIALS, try to always keep responses in mind, especially if they contain unusual information or pointers to where to find additional information!\n" + f"Think step by step and say everything out loud that you might be doing.\n" + f"Give a list of assumptions you base your tool calls on, before you start a tool call, and reason whether they held or not after the call has completed. 
Try to do multiple tool calls at once, you have a limited number of turn, but not tool calls.\n" + f"Do not brute-force anything, all information that you might need can be found on the webserver.\n" + f"THE MOST IMPORTANT THING TO DO IS, that if you see a flag ({self.flag_format_description}), you should submit it immediately." + ) + self._prompt_history.append({ "role": "system", "content": system_message }) + self.log.system_message(system_message) def all_flags_found(self): - self._log.console.print(Panel("All flags found! Congratulations!", title="system")) + self.log.status_message("All flags found! Congratulations!") self._all_flags_found = True def perform_round(self, turn: int): prompt = self._prompt_history # TODO: in the future, this should do some context truncation - result: LLMResult = None - stream = self.llm.stream_response(prompt, self._log.console, capabilities=self._capabilities) - for part in stream: - result = part + result_stream: Iterable[Union[ChoiceDelta, LLMResult]] = self.llm.stream_response(prompt, self.log.console, capabilities=self._capabilities, get_individual_updates=True) + result: Optional[LLMResult] = None + stream_output = self.log.stream_message("assistant") # TODO: do not hardcode the role + for delta in result_stream: + if isinstance(delta, LLMResult): + result = delta + break + if delta.content is not None: + stream_output.append(delta.content) + if result is None: + self.log.error_message("No result from the LLM") + return False + message_id = stream_output.finalize(result.tokens_query, result.tokens_response, result.duration) message: ChatCompletionMessage = result.result - message_id = self._log.log_db.add_log_message( - self._log.run_id, - message.role, - message.content, - result.tokens_query, - result.tokens_response, - result.duration, - ) self._prompt_history.append(result.result) if message.tool_calls is not None: for tool_call in message.tool_calls: - tic = time.perf_counter() - tool_call_result = ( - 
self._capabilities[tool_call.function.name] - .to_model() - .model_validate_json(tool_call.function.arguments) - .execute() - ) - toc = time.perf_counter() - - self._log.console.print( - f"\n[bold green on gray3]{' '*self._log.console.width}\nTOOL RESPONSE:[/bold green on gray3]" - ) - self._log.console.print(tool_call_result) - self._prompt_history.append(tool_message(tool_call_result, tool_call.id)) - self._log.log_db.add_log_tool_call( - self._log.run_id, - message_id, - tool_call.id, - tool_call.function.name, - tool_call.function.arguments, - tool_call_result, - toc - tic, - ) + tool_result = self.run_capability_json(message_id, tool_call.id, tool_call.function.name, tool_call.function.arguments) + self._prompt_history.append(tool_message(tool_result, tool_call.id)) return self._all_flags_found diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py b/src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py index d9c39d9..98781cb 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py @@ -163,8 +163,8 @@ def run_documentation(self, turn, move_type): """ prompt = self.prompt_engineer.generate_prompt(turn, move_type) response, completion = self.llm_handler.call_llm(prompt) - self._log, self._prompt_history, self.prompt_engineer = self.documentation_handler.document_response( - completion, response, self._log, self._prompt_history, self.prompt_engineer + self.log, self._prompt_history, self.prompt_engineer = self.documentation_handler.document_response( + completion, response, self.log, self._prompt_history, self.prompt_engineer ) diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/simple_web_api_testing.py b/src/hackingBuddyGPT/usecases/web_api_testing/simple_web_api_testing.py index 69d9d6a..6aff026 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/simple_web_api_testing.py +++ 
b/src/hackingBuddyGPT/usecases/web_api_testing/simple_web_api_testing.py @@ -110,7 +110,7 @@ def all_http_methods_found(self) -> None: Handles the event when all HTTP methods are found. Displays a congratulatory message and sets the _all_http_methods_found flag to True. """ - self._log.console.print(Panel("All HTTP methods found! Congratulations!", title="system")) + self.log.console.print(Panel("All HTTP methods found! Congratulations!", title="system")) self._all_http_methods_found = True def _setup_capabilities(self) -> None: @@ -156,12 +156,12 @@ def _handle_response(self, completion: Any, response: Any, purpose: str) -> None message = completion.choices[0].message tool_call_id: str = message.tool_calls[0].id command: str = pydantic_core.to_json(response).decode() - self._log.console.print(Panel(command, title="assistant")) + self.log.console.print(Panel(command, title="assistant")) self._prompt_history.append(message) - with self._log.console.status("[bold green]Executing that command..."): + with self.log.console.status("[bold green]Executing that command..."): result: Any = response.execute() - self._log.console.print(Panel(result[:30], title="tool")) + self.log.console.print(Panel(result[:30], title="tool")) if not isinstance(result, str): endpoint: str = str(response.action.path).split("/")[1] self._report_handler.write_endpoint_to_report(endpoint) diff --git a/src/hackingBuddyGPT/utils/configurable.py b/src/hackingBuddyGPT/utils/configurable.py index 52f35a5..bf7d083 100644 --- a/src/hackingBuddyGPT/utils/configurable.py +++ b/src/hackingBuddyGPT/utils/configurable.py @@ -3,7 +3,8 @@ import inspect import os from dataclasses import dataclass -from typing import Any, Dict, Type, TypeVar +from types import NoneType +from typing import Any, Dict, Type, TypeVar, Set, Union from dotenv import load_dotenv @@ -43,6 +44,12 @@ def get_default(key, default): ) +@dataclass +class ParserState: + global_parser_definitions: Set[str] = 
dataclasses.field(default_factory=lambda: set()) + global_configurations: Dict[Type, Dict[str, Any]] = dataclasses.field(default_factory=lambda: dict()) + + @dataclass class ParameterDefinition: """ @@ -54,14 +61,14 @@ class ParameterDefinition: default: Any description: str - def parser(self, name: str, parser: argparse.ArgumentParser): + def parser(self, name: str, parser: argparse.ArgumentParser, parser_state: ParserState): default = get_default(name, self.default) parser.add_argument( f"--{name}", type=self.type, default=default, required=default is None, help=self.description ) - def get(self, name: str, args: argparse.Namespace): + def get(self, name: str, args: argparse.Namespace, parser_state: ParserState): return getattr(args, name) @@ -78,26 +85,47 @@ class ComplexParameterDefinition(ParameterDefinition): """ parameters: ParameterDefinitions + global_parameter: bool transparent: bool = False - def parser(self, basename: str, parser: argparse.ArgumentParser): + def parser(self, basename: str, parser: argparse.ArgumentParser, parser_state: ParserState): + if self.global_parameter and self.name in parser_state.global_parser_definitions: + return + for name, parameter in self.parameters.items(): if isinstance(parameter, dict): - build_parser(parameter, parser, next_name(basename, name, parameter)) + build_parser(parameter, parser, parser_state, next_name(basename, name, parameter)) else: - parameter.parser(next_name(basename, name, parameter), parser) + parameter.parser(next_name(basename, name, parameter), parser, parser_state) + + if self.global_parameter: + parser_state.global_parser_definitions.add(self.name) + + def get(self, name: str, args: argparse.Namespace, parser_state: ParserState): + def make(name, args): + args = get_arguments(self.parameters, args, parser_state, name) + + def create(): + instance = self.type(**args) + if hasattr(instance, "init") and not getattr(self.type, "__transparent__", False): + instance.init() + setattr(instance, 
"configurable_recreate", create) + return instance + + return create() + + if not self.global_parameter: + return make(name, args) - def get(self, name: str, args: argparse.Namespace): - args = get_arguments(self.parameters, args, name) + if self.type in parser_state.global_configurations and self.name in parser_state.global_configurations[self.type]: + return parser_state.global_configurations[self.type][self.name] - def create(): - instance = self.type(**args) - if hasattr(instance, "init") and not getattr(self.type, "__transparent__", False): - instance.init() - instance.configurable_recreate = create - return instance + instance = make(name, args) + if self.type not in parser_state.global_configurations: + parser_state.global_configurations[self.type] = dict() + parser_state.global_configurations[self.type][self.name] = instance - return create() + return instance def get_class_parameters(cls, name: str = None, fields: Dict[str, dataclasses.Field] = None) -> ParameterDefinitions: @@ -137,17 +165,31 @@ def get_parameters(fun, basename: str, fields: Dict[str, dataclasses.Field] = No if field.type is not None: type = field.type + resolution_name = name + resolution_basename = basename + if getattr(type, "__global__", False): + resolution_name = getattr(type, "__global_name__", None) + if resolution_name is None: + resolution_name = name + resolution_basename = resolution_name + + # check if type is an Optional, and then get the actual type + if hasattr(type, "__origin__") and type.__origin__ is Union and len(type.__args__) == 2 and type.__args__[1] is NoneType: + type = type.__args__[0] + default = None + if hasattr(type, "__parameters__"): params[name] = ComplexParameterDefinition( - name, + resolution_name, type, default, description, - get_class_parameters(type, basename), + get_class_parameters(type, resolution_basename), + global_parameter=getattr(type, "__global__", False), transparent=getattr(type, "__transparent__", False), ) elif type in (str, int, float, 
bool): - params[name] = ParameterDefinition(name, type, default, description) + params[name] = ParameterDefinition(resolution_name, type, default, description) else: raise ValueError( f"Parameter {name} of {basename} must have str, int, bool, or a __parameters__ class as type, not {type}" @@ -156,13 +198,13 @@ def get_parameters(fun, basename: str, fields: Dict[str, dataclasses.Field] = No return params -def build_parser(parameters: ParameterDefinitions, parser: argparse.ArgumentParser, basename: str = ""): +def build_parser(parameters: ParameterDefinitions, parser: argparse.ArgumentParser, parser_state: ParserState, basename: str = ""): for name, parameter in parameters.items(): - parameter.parser(next_name(basename, name, parameter), parser) + parameter.parser(next_name(basename, name, parameter), parser, parser_state) -def get_arguments(parameters: ParameterDefinitions, args: argparse.Namespace, basename: str = "") -> Dict[str, Any]: - return {name: parameter.get(next_name(basename, name, parameter), args) for name, parameter in parameters.items()} +def get_arguments(parameters: ParameterDefinitions, args: argparse.Namespace, parser_state: ParserState, basename: str = "") -> Dict[str, Any]: + return {name: parameter.get(next_name(basename, name, parameter), args, parser_state) for name, parameter in parameters.items()} Configurable = Type # TODO: Define type @@ -189,7 +231,17 @@ def inner(cls) -> Configurable: T = TypeVar("T") -def transparent(subclass: T) -> T: +def Global(subclass: T, global_name: str = None) -> T: + class Cloned(subclass): + __global__ = True + __global_name__ = global_name + Cloned.__name__ = subclass.__name__ + Cloned.__qualname__ = subclass.__qualname__ + Cloned.__module__ = subclass.__module__ + return Cloned + + +def Transparent(subclass: T) -> T: """ setting a type to be transparent means, that it will not increase a level in the configuration tree, so if you have the following classes: @@ -209,6 +261,7 @@ def init(self): the 
configuration will be `--a` and `--b` instead of `--inner.a` and `--inner.b`. A transparent attribute will also not have its init function called automatically, so you will need to do that on your own, as seen in the Outer init. + The function is upper case on purpose, as it is supposed to be used in a Type context """ class Cloned(subclass): diff --git a/src/hackingBuddyGPT/utils/db_storage/db_storage.py b/src/hackingBuddyGPT/utils/db_storage/db_storage.py index 7f47382..b15853b 100644 --- a/src/hackingBuddyGPT/utils/db_storage/db_storage.py +++ b/src/hackingBuddyGPT/utils/db_storage/db_storage.py @@ -1,12 +1,100 @@ +from dataclasses import dataclass, field +from dataclasses_json import config, dataclass_json +import datetime import sqlite3 +from typing import Literal, Optional, Union -from hackingBuddyGPT.utils.configurable import configurable, parameter +from hackingBuddyGPT.utils.configurable import Global, configurable, parameter + + +timedelta_metadata = config(encoder=lambda td: td.total_seconds(), decoder=lambda seconds: datetime.timedelta(seconds=seconds)) +datetime_metadata = config(encoder=lambda dt: dt.isoformat(), decoder=lambda iso: datetime.datetime.fromisoformat(iso)) +optional_datetime_metadata = config(encoder=lambda dt: dt.isoformat() if dt else None, decoder=lambda iso: datetime.datetime.fromisoformat(iso) if iso else None) + + +StreamAction = Literal["append"] + + +@dataclass_json +@dataclass +class Run: + id: int + model: str + state: str + tag: str + started_at: datetime.datetime = field(metadata=datetime_metadata) + stopped_at: Optional[datetime.datetime] = field(metadata=optional_datetime_metadata) + configuration: str + + +@dataclass_json +@dataclass +class Section: + run_id: int + id: int + name: str + from_message: int + to_message: int + duration: datetime.timedelta = field(metadata=timedelta_metadata) + + +@dataclass_json +@dataclass +class Message: + run_id: int + id: int + version: int + conversation: str + role: str + content: str + 
duration: datetime.timedelta = field(metadata=timedelta_metadata) + tokens_query: int + tokens_response: int + + +@dataclass_json +@dataclass +class MessageStreamPart: + id: int + run_id: int + message_id: int + action: StreamAction + content: str + + +@dataclass_json +@dataclass +class ToolCall: + run_id: int + message_id: int + id: str + version: int + function_name: str + arguments: str + state: str + result_text: str + duration: datetime.timedelta = field(metadata=timedelta_metadata) + + +@dataclass_json +@dataclass +class ToolCallStreamPart: + id: int + run_id: int + message_id: int + tool_call_id: str + field: Literal["arguments", "result"] + action: StreamAction + content: str + + +LogTypes = Union[Run, Section, Message, MessageStreamPart, ToolCall, ToolCallStreamPart] @configurable("db_storage", "Stores the results of the experiments in a SQLite database") -class DbStorage: +class RawDbStorage: def __init__( - self, connection_string: str = parameter(desc="sqlite3 database connection string for logs", default=":memory:") + self, connection_string: str = parameter(desc="sqlite3 database connection string for logs", default="wintermute.sqlite3") ): self.connection_string = connection_string @@ -15,21 +103,10 @@ def init(self): self.setup_db() def connect(self): - self.db = sqlite3.connect(self.connection_string) + self.db = sqlite3.connect(self.connection_string, isolation_level=None) + self.db.row_factory = sqlite3.Row self.cursor = self.db.cursor() - def insert_or_select_cmd(self, name: str) -> int: - results = self.cursor.execute("SELECT id, name FROM commands WHERE name = ?", (name,)).fetchall() - - if len(results) == 0: - self.cursor.execute("INSERT INTO commands (name) VALUES (?)", (name,)) - return self.cursor.lastrowid - elif len(results) == 1: - return results[0][0] - else: - print("this should not be happening: " + str(results)) - return -1 - def setup_db(self): # create tables self.cursor.execute(""" @@ -40,228 +117,175 @@ def setup_db(self): tag 
TEXT, started_at text, stopped_at text, - rounds INTEGER, configuration TEXT ) """) self.cursor.execute(""" - CREATE TABLE IF NOT EXISTS commands ( - id INTEGER PRIMARY KEY, - name string unique - ) - """) - self.cursor.execute(""" - CREATE TABLE IF NOT EXISTS queries ( + CREATE TABLE IF NOT EXISTS sections ( run_id INTEGER, - round INTEGER, - cmd_id INTEGER, - query TEXT, - response TEXT, + id INTEGER, + name TEXT, + from_message INTEGER, + to_message INTEGER, duration REAL, - tokens_query INTEGER, - tokens_response INTEGER, - prompt TEXT, - answer TEXT + PRIMARY KEY (run_id, id), + FOREIGN KEY (run_id) REFERENCES runs (id) ) """) self.cursor.execute(""" CREATE TABLE IF NOT EXISTS messages ( run_id INTEGER, - message_id INTEGER, + conversation TEXT, + id INTEGER, + version INTEGER DEFAULT 0, role TEXT, content TEXT, duration REAL, tokens_query INTEGER, - tokens_response INTEGER + tokens_response INTEGER, + PRIMARY KEY (run_id, id), + FOREIGN KEY (run_id) REFERENCES runs (id) ) """) self.cursor.execute(""" CREATE TABLE IF NOT EXISTS tool_calls ( run_id INTEGER, message_id INTEGER, - tool_call_id INTEGER, + id TEXT, + version INTEGER DEFAULT 0, function_name TEXT, arguments TEXT, + state TEXT, result_text TEXT, - duration REAL + duration REAL, + PRIMARY KEY (run_id, message_id, id), + FOREIGN KEY (run_id, message_id) REFERENCES messages (run_id, id) ) """) - # insert commands - self.query_cmd_id = self.insert_or_select_cmd("query_cmd") - self.analyze_response_id = self.insert_or_select_cmd("analyze_response") - self.state_update_id = self.insert_or_select_cmd("update_state") + def get_runs(self) -> list[Run]: + def deserialize(row): + row = dict(row) + row["started_at"] = datetime.datetime.fromisoformat(row["started_at"]) + row["stopped_at"] = datetime.datetime.fromisoformat(row["stopped_at"]) if row["stopped_at"] else None + return row - def create_new_run(self, model, tag): + self.cursor.execute("SELECT * FROM runs") + return [Run(**deserialize(row)) for row in 
self.cursor.fetchall()] + + def get_sections_by_run(self, run_id: int) -> list[Section]: + def deserialize(row): + row = dict(row) + row["duration"] = datetime.timedelta(seconds=row["duration"]) + return row + + self.cursor.execute("SELECT * FROM sections WHERE run_id = ?", (run_id,)) + return [Section(**deserialize(row)) for row in self.cursor.fetchall()] + + def get_messages_by_run(self, run_id: int) -> list[Message]: + def deserialize(row): + row = dict(row) + row["duration"] = datetime.timedelta(seconds=row["duration"]) + return row + + self.cursor.execute("SELECT * FROM messages WHERE run_id = ?", (run_id,)) + return [Message(**deserialize(row)) for row in self.cursor.fetchall()] + + def get_tool_calls_by_run(self, run_id: int) -> list[ToolCall]: + def deserialize(row): + row = dict(row) + row["duration"] = datetime.timedelta(seconds=row["duration"]) + return row + + self.cursor.execute("SELECT * FROM tool_calls WHERE run_id = ?", (run_id,)) + return [ToolCall(**deserialize(row)) for row in self.cursor.fetchall()] + + def create_run(self, model: str, tag: str, started_at: datetime.datetime, configuration: str) -> int: self.cursor.execute( - "INSERT INTO runs (model, state, tag, started_at) VALUES (?, ?, ?, datetime('now'))", - (model, "in progress", tag), + "INSERT INTO runs (model, state, tag, started_at, configuration) VALUES (?, ?, ?, ?, ?)", + (model, "in progress", tag, started_at, configuration), ) return self.cursor.lastrowid - def add_log_query(self, run_id, round, cmd, result, answer): + def add_message(self, run_id: int, message_id: int, conversation: Optional[str], role: str, content: str, tokens_query: int, tokens_response: int, duration: datetime.timedelta): self.cursor.execute( - "INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response, prompt, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - ( - run_id, - round, - self.query_cmd_id, - cmd, - result, - answer.duration, - answer.tokens_query, - 
answer.tokens_response, - answer.prompt, - answer.answer, - ), + "INSERT INTO messages (run_id, conversation, id, role, content, tokens_query, tokens_response, duration) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + (run_id, conversation, message_id, role, content, tokens_query, tokens_response, duration.total_seconds()) ) - def add_log_analyze_response(self, run_id, round, cmd, result, answer): + def add_or_update_message(self, run_id: int, message_id: int, conversation: Optional[str], role: str, content: str, tokens_query: int, tokens_response: int, duration: datetime.timedelta): self.cursor.execute( - "INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response, prompt, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - ( - run_id, - round, - self.analyze_response_id, - cmd, - result, - answer.duration, - answer.tokens_query, - answer.tokens_response, - answer.prompt, - answer.answer, - ), + "SELECT COUNT(*) FROM messages WHERE run_id = ? AND id = ?", + (run_id, message_id), ) - - def add_log_update_state(self, run_id, round, cmd, result, answer): - if answer is not None: + if self.cursor.fetchone()[0] == 0: self.cursor.execute( - "INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response, prompt, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - ( - run_id, - round, - self.state_update_id, - cmd, - result, - answer.duration, - answer.tokens_query, - answer.tokens_response, - answer.prompt, - answer.answer, - ), + "INSERT INTO messages (run_id, conversation, id, role, content, tokens_query, tokens_response, duration) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + (run_id, conversation, message_id, role, content, tokens_query, tokens_response, duration.total_seconds()), ) else: - self.cursor.execute( - "INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response, prompt, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - (run_id, round, self.state_update_id, cmd, 
result, 0, 0, 0, "", ""), - ) + if len(content) > 0: + self.cursor.execute( + "UPDATE messages SET conversation = ?, role = ?, content = ?, tokens_query = ?, tokens_response = ?, duration = ? WHERE run_id = ? AND id = ?", + (conversation, role, content, tokens_query, tokens_response, duration.total_seconds(), run_id, message_id), + ) + else: + self.cursor.execute( + "UPDATE messages SET conversation = ?, role = ?, tokens_query = ?, tokens_response = ?, duration = ? WHERE run_id = ? AND id = ?", + (conversation, role, tokens_query, tokens_response, duration.total_seconds(), run_id, message_id), + ) - def add_log_message(self, run_id: int, role: str, content: str, tokens_query: int, tokens_response: int, duration): + def add_section(self, run_id: int, section_id: int, name: str, from_message: int, to_message: int, duration: datetime.timedelta): self.cursor.execute( - "INSERT INTO messages (run_id, message_id, role, content, tokens_query, tokens_response, duration) VALUES (?, (SELECT COALESCE(MAX(message_id), 0) + 1 FROM messages WHERE run_id = ?), ?, ?, ?, ?, ?)", - (run_id, run_id, role, content, tokens_query, tokens_response, duration), + "INSERT OR REPLACE INTO sections (run_id, id, name, from_message, to_message, duration) VALUES (?, ?, ?, ?, ?, ?)", + (run_id, section_id, name, from_message, to_message, duration.total_seconds()) ) - self.cursor.execute("SELECT MAX(message_id) FROM messages WHERE run_id = ?", (run_id,)) - return self.cursor.fetchone()[0] - - def add_log_tool_call( - self, - run_id: int, - message_id: int, - tool_call_id: str, - function_name: str, - arguments: str, - result_text: str, - duration, - ): + + def add_tool_call(self, run_id: int, message_id: int, tool_call_id: str, function_name: str, arguments: str, result_text: str, duration: datetime.timedelta): self.cursor.execute( - "INSERT INTO tool_calls (run_id, message_id, tool_call_id, function_name, arguments, result_text, duration) VALUES (?, ?, ?, ?, ?, ?, ?)", - (run_id, message_id, 
tool_call_id, function_name, arguments, result_text, duration), + "INSERT INTO tool_calls (run_id, message_id, id, function_name, arguments, result_text, duration) VALUES (?, ?, ?, ?, ?, ?, ?)", + (run_id, message_id, tool_call_id, function_name, arguments, result_text, duration.total_seconds()), ) - def get_round_data(self, run_id, round, explanation, status_update): - rows = self.cursor.execute( - "select cmd_id, query, response, duration, tokens_query, tokens_response from queries where run_id = ? and round = ?", - (run_id, round), - ).fetchall() - if len(rows) == 0: - return [] - - for row in rows: - if row[0] == self.query_cmd_id: - cmd = row[1] - size_resp = str(len(row[2])) - duration = f"{row[3]:.4f}" - tokens = f"{row[4]}/{row[5]}" - if row[0] == self.analyze_response_id and explanation: - reason = row[2] - analyze_time = f"{row[3]:.4f}" - analyze_token = f"{row[4]}/{row[5]}" - if row[0] == self.state_update_id and status_update: - state_time = f"{row[3]:.4f}" - state_token = f"{row[4]}/{row[5]}" - - result = [duration, tokens, cmd, size_resp] - if explanation: - result += [analyze_time, analyze_token, reason] - if status_update: - result += [state_time, state_token] - return result - - def get_max_round_for(self, run_id): - run = self.cursor.execute("select max(round) from queries where run_id = ?", (run_id,)).fetchone() - if run is not None: - return run[0] - else: - return None + def handle_message_update(self, run_id: int, message_id: int, action: StreamAction, content: str): + if action != "append": + raise ValueError("unsupported action" + action) + self.cursor.execute( + "UPDATE messages SET content = content || ?, version = version + 1 WHERE run_id = ? 
AND id = ?", + (content, run_id, message_id), + ) - def get_run_data(self, run_id): - run = self.cursor.execute("select * from runs where id = ?", (run_id,)).fetchone() - if run is not None: - return run[1], run[2], run[4], run[3], run[7], run[8] + def finalize_message(self, run_id: int, message_id: int, tokens_query: int, tokens_response: int, duration: datetime.timedelta, overwrite_finished_message: Optional[str] = None): + if overwrite_finished_message: + self.cursor.execute( + "UPDATE messages SET content = ?, tokens_query = ?, tokens_response = ?, duration = ? WHERE run_id = ? AND id = ?", + (overwrite_finished_message, tokens_query, tokens_response, duration.total_seconds(), run_id, message_id), + ) else: - return None - - def get_log_overview(self): - result = {} - - max_rounds = self.cursor.execute("select run_id, max(round) from queries group by run_id").fetchall() - for row in max_rounds: - state = self.cursor.execute("select state from runs where id = ?", (row[0],)).fetchone() - last_cmd = self.cursor.execute( - "select query from queries where run_id = ? and round = ?", (row[0], row[1]) - ).fetchone() - - result[row[0]] = {"max_round": int(row[1]) + 1, "state": state[0], "last_cmd": last_cmd[0]} - - return result - - def get_cmd_history(self, run_id): - rows = self.cursor.execute( - "select query, response from queries where run_id = ? and cmd_id = ? order by round asc", - (run_id, self.query_cmd_id), - ).fetchall() - - result = [] - - for row in rows: - result.append([row[0], row[1]]) + self.cursor.execute( + "UPDATE messages SET tokens_query = ?, tokens_response = ?, duration = ? WHERE run_id = ? 
AND id = ?", + (tokens_query, tokens_response, duration.total_seconds(), run_id, message_id), + ) - return result + def update_run(self, run_id: int, model: str, state: str, tag: str, started_at: datetime.datetime, stopped_at: datetime.datetime, configuration: str): + self.cursor.execute( + "UPDATE runs SET model = ?, state = ?, tag = ?, started_at = ?, stopped_at = ?, configuration = ? WHERE id = ?", + (model, state, tag, started_at, stopped_at, configuration, run_id), + ) - def run_was_success(self, run_id, round): + def run_was_success(self, run_id): self.cursor.execute( - "update runs set state=?,stopped_at=datetime('now'), rounds=? where id = ?", - ("got root", round, run_id), + "update runs set state=?,stopped_at=datetime('now') where id = ?", + ("got root", run_id), ) self.db.commit() - def run_was_failure(self, run_id, round): + def run_was_failure(self, run_id: int, reason: str): self.cursor.execute( - "update runs set state=?, stopped_at=datetime('now'), rounds=? where id = ?", - ("reached max runs", round, run_id), + "update runs set state=?, stopped_at=datetime('now') where id = ?", + (reason, run_id), ) self.db.commit() - def commit(self): - self.db.commit() + +DbStorage = Global(RawDbStorage) diff --git a/src/hackingBuddyGPT/utils/llm_util.py b/src/hackingBuddyGPT/utils/llm_util.py index 80b9480..fc04dc6 100644 --- a/src/hackingBuddyGPT/utils/llm_util.py +++ b/src/hackingBuddyGPT/utils/llm_util.py @@ -1,4 +1,5 @@ import abc +import datetime import re import typing from dataclasses import dataclass @@ -20,7 +21,7 @@ class LLMResult: result: typing.Any prompt: str answer: str - duration: float = 0 + duration: datetime.timedelta = datetime.timedelta(0) tokens_query: int = 0 tokens_response: int = 0 diff --git a/src/hackingBuddyGPT/utils/logging.py b/src/hackingBuddyGPT/utils/logging.py new file mode 100644 index 0000000..0ec0ca5 --- /dev/null +++ b/src/hackingBuddyGPT/utils/logging.py @@ -0,0 +1,381 @@ +import datetime +from enum import Enum +import time 
+from dataclasses import dataclass, field +from functools import wraps +from typing import Optional, Union +import threading + +from dataclasses_json.api import dataclass_json + +from hackingBuddyGPT.utils import configurable, DbStorage, Console, LLMResult +from hackingBuddyGPT.utils.db_storage.db_storage import StreamAction +from hackingBuddyGPT.utils.configurable import Global, Transparent +from rich.console import Group +from rich.panel import Panel +from websockets.sync.client import ClientConnection, connect as ws_connect + +from hackingBuddyGPT.utils.db_storage.db_storage import Run, Section, Message, MessageStreamPart, ToolCall, ToolCallStreamPart + + +def log_section(name: str, logger_field_name: str = "log"): + def outer(fun): + @wraps(fun) + def inner(self, *args, **kwargs): + logger = getattr(self, logger_field_name) + with logger.section(name): + return fun(self, *args, **kwargs) + return inner + return outer + + +def log_conversation(conversation: str, start_section: bool = False, logger_field_name: str = "log"): + def outer(fun): + @wraps(fun) + def inner(self, *args, **kwargs): + logger = getattr(self, logger_field_name) + with logger.conversation(conversation, start_section): + return fun(self, *args, **kwargs) + return inner + return outer + + +MessageData = Union[Run, Section, Message, MessageStreamPart, ToolCall, ToolCallStreamPart] + + +class MessageType(str, Enum): + MESSAGE_REQUEST = "MessageRequest" + RUN = "Run" + SECTION = "Section" + MESSAGE = "Message" + MESSAGE_STREAM_PART = "MessageStreamPart" + TOOL_CALL = "ToolCall" + TOOL_CALL_STREAM_PART = "ToolCallStreamPart" + + def get_class(self): + return { + "Run": Run, + "Section": Section, + "Message": Message, + "MessageStreamPart": MessageStreamPart, + "ToolCall": ToolCall, + "ToolCallStreamPart": ToolCallStreamPart, + }[self.value] + + +@dataclass_json +@dataclass +class ControlMessage: + type: MessageType + data: MessageData + + @classmethod + def from_dict(cls, data): + type_ = 
MessageType(data['type']) + data_class = type_.get_class() + data_instance = data_class.from_dict(data['data']) + return cls(type=type_, data=data_instance) + + +@configurable("logger", "Logger") +@dataclass +class Logger: + log_db: DbStorage + console: Console + tag: str = "" + + run: Run = field(init=False, default=None) + + _last_message_id: int = 0 + _last_section_id: int = 0 + _current_conversation: Optional[str] = None + + def start_run(self, name: str, configuration: str): + if self.run is not None: + raise ValueError("Run already started") + start_time = datetime.datetime.now() + run_id = self.log_db.create_run(name, self.tag, start_time , configuration) + self.run = Run(run_id, name, "", self.tag, start_time, None, configuration) + + def section(self, name: str) -> "LogSectionContext": + return LogSectionContext(self, name, self._last_message_id) + + def log_section(self, name: str, from_message: int, to_message: int, duration: datetime.timedelta): + section_id = self._last_section_id + self._last_section_id += 1 + + self.log_db.add_section(self.run.id, section_id, name, from_message, to_message, duration) + + return section_id + + def finalize_section(self, section_id: int, name: str, from_message: int, duration: datetime.timedelta): + self.log_db.add_section(self.run.id, section_id, name, from_message, self._last_message_id, duration) + + def conversation(self, conversation: str, start_section: bool = False) -> "LogConversationContext": + return LogConversationContext(self, start_section, conversation, self._current_conversation) + + def add_message(self, role: str, content: str, tokens_query: int, tokens_response: int, duration: datetime.timedelta) -> int: + message_id = self._last_message_id + self._last_message_id += 1 + + self.log_db.add_message(self.run.id, message_id, self._current_conversation, role, content, tokens_query, tokens_response, duration) + self.console.print(Panel(content, title=(("" if self._current_conversation is None else 
f"{self._current_conversation} - ") + role))) + + return message_id + + def _add_or_update_message(self, message_id: int, conversation: Optional[str], role: str, content: str, tokens_query: int, tokens_response: int, duration: datetime.timedelta): + self.log_db.add_or_update_message(self.run.id, message_id, conversation, role, content, tokens_query, tokens_response, duration) + + def add_tool_call(self, message_id: int, tool_call_id: str, function_name: str, arguments: str, result_text: str, duration: datetime.timedelta): + self.console.print(Panel( + Group( + Panel(arguments, title="arguments"), + Panel(result_text, title="result"), + ), + title=f"Tool Call: {function_name}")) + self.log_db.add_tool_call(self.run.id, message_id, tool_call_id, function_name, arguments, result_text, duration) + + def run_was_success(self): + self.status_message("Run finished successfully") + self.log_db.run_was_success(self.run.id) + + def run_was_failure(self, reason: str, details: Optional[str] = None): + full_reason = reason + ("" if details is None else f": {details}") + self.status_message(f"Run failed: {full_reason}") + self.log_db.run_was_failure(self.run.id, reason) + + def status_message(self, message: str): + self.add_message("status", message, 0, 0, datetime.timedelta(0)) + + def system_message(self, message: str): + self.add_message("system", message, 0, 0, datetime.timedelta(0)) + + def call_response(self, llm_result: LLMResult) -> int: + self.system_message(llm_result.prompt) + return self.add_message("assistant", llm_result.answer, llm_result.tokens_query, llm_result.tokens_response, llm_result.duration) + + def stream_message(self, role: str): + message_id = self._last_message_id + self._last_message_id += 1 + + return MessageStreamLogger(self, message_id, self._current_conversation, role) + + def add_message_update(self, message_id: int, action: StreamAction, content: str): + self.log_db.handle_message_update(self.run.id, message_id, action, content) + + 
+@configurable("logger", "Logger") +@dataclass +class RemoteLogger: + console: Console + log_server_address: str = "localhost:4444" + + tag: str = "" + run: Run = field(init=False, default=None) + + _last_message_id: int = 0 + _last_section_id: int = 0 + _current_conversation: Optional[str] = None + _upstream_websocket: ClientConnection = None + _keepalive_thread: Optional[threading.Thread] = None + _keepalive_stop_event: threading.Event = field(init=False, default_factory=threading.Event) + + def __del__(self): + print("running log deleter") + if self._upstream_websocket: + self._upstream_websocket.close() + if self._keepalive_thread: + self._keepalive_stop_event.set() + self._keepalive_thread.join() + + def init_websocket(self): + self._upstream_websocket = ws_connect(f"ws://{self.log_server_address}/ingress") # TODO: we want to support wss at some point + # self.start_keepalive() + + def start_keepalive(self): + self._keepalive_stop_event.clear() + self._keepalive_thread = threading.Thread(target=self.keepalive) + self._keepalive_thread.start() + + def keepalive(self): + while not self._keepalive_stop_event.is_set(): + try: + self._upstream_websocket.ping() + self._upstream_websocket.pong() + except Exception as e: + import traceback + traceback.print_exc() + print("Keepalive error:", e) + self._keepalive_stop_event.set() + time.sleep(5) + + def send(self, type: MessageType, data: MessageData): + self._upstream_websocket.send(ControlMessage(type, data).to_json()) + + def start_run(self, name: str, configuration: str, tag: Optional[str] = None, start_time: Optional[datetime.datetime] = None, end_time: Optional[datetime.datetime] = None): + if self._upstream_websocket is None: + self.init_websocket() + + if self.run is not None: + raise ValueError("Run already started") + + if tag is None: + tag = self.tag + + if start_time is None: + start_time = datetime.datetime.now() + + self.run = Run(None, name, None, tag, start_time, None, configuration) + 
self.send(MessageType.RUN, self.run) + self.run = Run.from_json(self._upstream_websocket.recv()) + + def section(self, name: str) -> "LogSectionContext": + return LogSectionContext(self, name, self._last_message_id) + + def log_section(self, name: str, from_message: int, to_message: int, duration: datetime.timedelta): + section_id = self._last_section_id + self._last_section_id += 1 + + section = Section(self.run.id, section_id, name, from_message, to_message, duration) + self.send(MessageType.SECTION, section) + + return section_id + + def finalize_section(self, section_id: int, name: str, from_message: int, duration: datetime.timedelta): + self.send(MessageType.SECTION, Section(self.run.id, section_id, name, from_message, self._last_message_id, duration)) + + def conversation(self, conversation: str, start_section: bool = False) -> "LogConversationContext": + return LogConversationContext(self, start_section, conversation, self._current_conversation) + + def add_message(self, role: str, content: str, tokens_query: int, tokens_response: int, duration: datetime.timedelta) -> int: + message_id = self._last_message_id + self._last_message_id += 1 + + msg = Message(self.run.id, message_id, version=1, conversation=self._current_conversation, role=role, content=content, duration=duration, tokens_query=tokens_query, tokens_response=tokens_response) + self.send(MessageType.MESSAGE, msg) + self.console.print(Panel(content, title=(("" if self._current_conversation is None else f"{self._current_conversation} - ") + role))) + + return message_id + + def _add_or_update_message(self, message_id: int, conversation: Optional[str], role: str, content: str, tokens_query: int, tokens_response: int, duration: datetime.timedelta): + msg = Message(self.run.id, message_id, version=0, conversation=conversation, role=role, content=content, duration=duration, tokens_query=tokens_query, tokens_response=tokens_response) + self.send(MessageType.MESSAGE, msg) + + def add_tool_call(self, 
message_id: int, tool_call_id: str, function_name: str, arguments: str, result_text: str, duration: datetime.timedelta): + self.console.print(Panel( + Group( + Panel(arguments, title="arguments"), + Panel(result_text, title="result"), + ), + title=f"Tool Call: {function_name}")) + tc = ToolCall(self.run.id, message_id, tool_call_id, 0, function_name, arguments, "success", result_text, duration) + self.send(MessageType.TOOL_CALL, tc) + + def run_was_success(self): + self.status_message("Run finished successfully") + self.run.stopped_at = datetime.datetime.now() + self.run.state = "success" + self.send(MessageType.RUN, self.run) + self.run = Run.from_json(self._upstream_websocket.recv()) + + def run_was_failure(self, reason: str, details: Optional[str] = None): + full_reason = reason + ("" if details is None else f": {details}") + self.status_message(f"Run failed: {full_reason}") + self.run.stopped_at = datetime.datetime.now() + self.run.state = reason + self.send(MessageType.RUN, self.run) + self.run = Run.from_json(self._upstream_websocket.recv()) + + def status_message(self, message: str): + self.add_message("status", message, 0, 0, datetime.timedelta(0)) + + def system_message(self, message: str): + self.add_message("system", message, 0, 0, datetime.timedelta(0)) + + def call_response(self, llm_result: LLMResult) -> int: + self.system_message(llm_result.prompt) + return self.add_message("assistant", llm_result.answer, llm_result.tokens_query, llm_result.tokens_response, llm_result.duration) + + def stream_message(self, role: str): + message_id = self._last_message_id + self._last_message_id += 1 + + return MessageStreamLogger(self, message_id, self._current_conversation, role) + + def add_message_update(self, message_id: int, action: StreamAction, content: str): + part = MessageStreamPart(id=None, run_id=self.run.id, message_id=message_id, action=action, content=content) + self.send(MessageType.MESSAGE_STREAM_PART, part) + + +@dataclass +class LogSectionContext: 
+ logger: Logger + name: str + from_message: int + + _section_id: int = 0 + + def __enter__(self): + self._start = datetime.datetime.now() + self._section_id = self.logger.log_section(self.name, self.from_message, None, datetime.timedelta(0)) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + duration = datetime.datetime.now() - self._start + self.logger.finalize_section(self._section_id, self.name, self.from_message, duration) + + +@dataclass +class LogConversationContext: + logger: Logger + with_section: bool + conversation: str + previous_conversation: Optional[str] + + _section: Optional[LogSectionContext] = None + + def __enter__(self): + if self.with_section: + self._section = LogSectionContext(self.logger, self.conversation, self.logger._last_message_id) + self._section.__enter__() + self.logger._current_conversation = self.conversation + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._section is not None: + self._section.__exit__(exc_type, exc_val, exc_tb) + del self._section + self.logger._current_conversation = self.previous_conversation + + +@dataclass +class MessageStreamLogger: + logger: Logger + message_id: int + conversation: Optional[str] + role: str + + _completed: bool = False + + def __post_init__(self): + self.logger._add_or_update_message(self.message_id, self.conversation, self.role, "", 0, 0, datetime.timedelta(0)) + + def __del__(self): + if not self._completed: + print(f"streamed message was not finalized ({self.logger.run.id}, {self.message_id}), please make sure to call finalize() on MessageStreamLogger objects") + self.finalize(0, 0, datetime.timedelta(0)) + + def append(self, content: str): + if self._completed: + raise ValueError("MessageStreamLogger already finalized") + self.logger.add_message_update(self.message_id, "append", content) + + def finalize(self, tokens_query: int, tokens_response: int, duration: datetime.timedelta, overwrite_finished_message: Optional[str] = None): + 
self._completed = True + self.logger._add_or_update_message(self.message_id, self.conversation, self.role, "", tokens_query, tokens_response, duration) + return self.message_id + + +GlobalLocalLogger = Global(Transparent(Logger)) +GlobalRemoteLogger = Global(Transparent(RemoteLogger)) +GlobalLogger = GlobalRemoteLogger diff --git a/src/hackingBuddyGPT/utils/openai/openai_lib.py b/src/hackingBuddyGPT/utils/openai/openai_lib.py index 654799d..5059083 100644 --- a/src/hackingBuddyGPT/utils/openai/openai_lib.py +++ b/src/hackingBuddyGPT/utils/openai/openai_lib.py @@ -1,10 +1,11 @@ -import time +import datetime from dataclasses import dataclass from typing import Dict, Iterable, Optional, Union import instructor import openai import tiktoken +from dataclasses import dataclass from openai.types import CompletionUsage from openai.types.chat import ( ChatCompletionChunk, @@ -12,6 +13,7 @@ ChatCompletionMessageParam, ChatCompletionMessageToolCall, ) +from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.chat.chat_completion_message_tool_call import Function from rich.console import Console @@ -49,8 +51,8 @@ def client(self) -> openai.OpenAI: def instructor(self) -> instructor.Instructor: return instructor.from_openai(self.client) - def get_response(self, prompt, *, capabilities: Dict[str, Capability] = None, **kwargs) -> LLMResult: - """# TODO: re-enable compatibility layer + def get_response(self, prompt, *, capabilities: Optional[Dict[str, Capability] ] = None, **kwargs) -> LLMResult: + """ # TODO: re-enable compatibility layer if isinstance(prompt, str) or hasattr(prompt, "render"): prompt = {"role": "user", "content": prompt} @@ -66,35 +68,38 @@ def get_response(self, prompt, *, capabilities: Dict[str, Capability] = None, ** if capabilities: tools = capabilities_to_tools(capabilities) - tic = time.perf_counter() + tic = datetime.datetime.now() response = self._client.chat.completions.create( model=self.model, messages=prompt, tools=tools, ) - 
toc = time.perf_counter() + duration = datetime.datetime.now() - tic message = response.choices[0].message return LLMResult( message, str(prompt), message.content, - toc - tic, + duration, response.usage.prompt_tokens, response.usage.completion_tokens, ) - def stream_response( - self, - prompt: Iterable[ChatCompletionMessageParam], - console: Console, - capabilities: Dict[str, Capability] = None, - ) -> Iterable[Union[ChatCompletionChunk, LLMResult]]: + def stream_response(self, prompt: Iterable[ChatCompletionMessageParam], console: Console, capabilities: Dict[str, Capability] = None, get_individual_updates=False) -> Union[LLMResult, Iterable[Union[ChoiceDelta, LLMResult]]]: + generator = self._stream_response(prompt, console, capabilities) + + if get_individual_updates: + return generator + + return list(generator)[-1] + + def _stream_response(self, prompt: Iterable[ChatCompletionMessageParam], console: Console, capabilities: Dict[str, Capability] = None) -> Iterable[Union[ChoiceDelta, LLMResult]]: tools = None if capabilities: tools = capabilities_to_tools(capabilities) - tic = time.perf_counter() + tic = datetime.datetime.now() chunks = self._client.chat.completions.create( model=self.model, messages=prompt, @@ -149,12 +154,13 @@ def stream_response( message.tool_calls[tool_call.index].function.arguments += tool_call.function.arguments outputs += 1 + yield delta + if chunk.usage is not None: usage = chunk.usage if outputs > 1: print("WARNING: Got more than one output in the stream response") - yield chunk console.print() if usage is None: @@ -164,7 +170,7 @@ def stream_response( if len(message.tool_calls) == 0: # the openAI API does not like getting empty tool call lists message.tool_calls = None - toc = time.perf_counter() + toc = datetime.datetime.now() yield LLMResult( message, str(prompt), @@ -173,7 +179,6 @@ def stream_response( usage.prompt_tokens, usage.completion_tokens, ) - pass def encode(self, query) -> list[int]: return 
tiktoken.encoding_for_model(self.model).encode(query) diff --git a/src/hackingBuddyGPT/utils/openai/openai_llm.py b/src/hackingBuddyGPT/utils/openai/openai_llm.py index 7553ee0..29193f9 100644 --- a/src/hackingBuddyGPT/utils/openai/openai_llm.py +++ b/src/hackingBuddyGPT/utils/openai/openai_llm.py @@ -1,4 +1,5 @@ import time +import datetime from dataclasses import dataclass import requests @@ -41,7 +42,7 @@ def get_response(self, prompt, *, retry: int = 0, **kwargs) -> LLMResult: data = {"model": self.model, "messages": [{"role": "user", "content": prompt}]} try: - tic = time.perf_counter() + tic = datetime.datetime.now() response = requests.post(f'{self.api_url}{self.api_path}', headers=headers, json=data, timeout=self.api_timeout) if response.status_code == 429: @@ -63,18 +64,18 @@ def get_response(self, prompt, *, retry: int = 0, **kwargs) -> LLMResult: # now extract the JSON status message # TODO: error handling.. - toc = time.perf_counter() response = response.json() result = response["choices"][0]["message"]["content"] tok_query = response["usage"]["prompt_tokens"] tok_res = response["usage"]["completion_tokens"] + duration = datetime.datetime.now() - tic - return LLMResult(result, prompt, result, toc - tic, tok_query, tok_res) + return LLMResult(result, prompt, result, duration, tok_query, tok_res) def encode(self, query) -> list[int]: # I know this is crappy for all non-openAI models but sadly this # has to be good enough for now - if self.model.startswith("gpt-"): + if self.model.startswith("gpt-") and not self.model.startswith("gpt-4o"): encoding = tiktoken.encoding_for_model(self.model) else: encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") diff --git a/tests/integration_minimal_test.py b/tests/integration_minimal_test.py index c6f00e9..fe4a8e7 100644 --- a/tests/integration_minimal_test.py +++ b/tests/integration_minimal_test.py @@ -1,5 +1,6 @@ from typing import Tuple +from hackingBuddyGPT.utils.logging import Logger from 
hackingBuddyGPT.usecases.examples.agent import ( ExPrivEscLinux, ExPrivEscLinuxUseCase, @@ -80,6 +81,11 @@ def test_linuxprivesc(): log_db.init() + log = Logger( + log_db=log_db, + console=console, + tag="integration_test_linuxprivesc", + ) priv_esc = LinuxPrivescUseCase( agent=LinuxPrivesc( conn=conn, @@ -87,14 +93,13 @@ def test_linuxprivesc(): disable_history=False, hint="", llm=llm, + log=log, ), - log_db=log_db, - console=console, - tag="integration_test_linuxprivesc", + log=log, max_turns=len(llm.responses), ) - priv_esc.init() + priv_esc.init({}) result = priv_esc.run() assert result is True @@ -107,15 +112,18 @@ def test_minimal_agent(): log_db.init() - priv_esc = ExPrivEscLinuxUseCase( - agent=ExPrivEscLinux(conn=conn, llm=llm), + log = Logger( log_db=log_db, console=console, tag="integration_test_minimallinuxprivesc", - max_turns=len(llm.responses), + ) + priv_esc = ExPrivEscLinuxUseCase( + agent=ExPrivEscLinux(conn=conn, llm=llm, log=log), + log=log, + max_turns=len(llm.responses) ) - priv_esc.init() + priv_esc.init({}) result = priv_esc.run() assert result is True @@ -128,14 +136,17 @@ def test_minimal_agent_state(): log_db.init() - priv_esc = ExPrivEscLinuxTemplatedUseCase( - agent=ExPrivEscLinuxTemplated(conn=conn, llm=llm), + log = Logger( log_db=log_db, console=console, tag="integration_test_linuxprivesc", - max_turns=len(llm.responses), + ) + priv_esc = ExPrivEscLinuxTemplatedUseCase( + agent=ExPrivEscLinuxTemplated(conn=conn, llm=llm, log=log), + log=log, + max_turns=len(llm.responses) ) - priv_esc.init() + priv_esc.init({}) result = priv_esc.run() assert result is True diff --git a/tests/test_web_api_documentation.py b/tests/test_web_api_documentation.py index f26afea..8b95d88 100644 --- a/tests/test_web_api_documentation.py +++ b/tests/test_web_api_documentation.py @@ -1,6 +1,7 @@ import unittest from unittest.mock import MagicMock, patch +from hackingBuddyGPT.utils.logging import Logger from 
hackingBuddyGPT.usecases.web_api_testing.simple_openapi_documentation import ( SimpleWebAPIDocumentation, SimpleWebAPIDocumentationUseCase, @@ -17,16 +18,19 @@ def setUp(self, MockOpenAILib): console = Console() log_db.init() - self.agent = SimpleWebAPIDocumentation(llm=self.mock_llm) - self.agent.init() - self.simple_api_testing = SimpleWebAPIDocumentationUseCase( - agent=self.agent, + log = Logger( log_db=log_db, console=console, tag="webApiDocumentation", + ) + self.agent = SimpleWebAPIDocumentation(llm=self.mock_llm, log=log) + self.agent.init() + self.simple_api_testing = SimpleWebAPIDocumentationUseCase( + agent=self.agent, + log=log, max_turns=len(self.mock_llm.responses), ) - self.simple_api_testing.init() + self.simple_api_testing.init({}) def test_initial_prompt(self): # Test if the initial prompt is set correctly diff --git a/tests/test_web_api_testing.py b/tests/test_web_api_testing.py index 84137e5..2071ca8 100644 --- a/tests/test_web_api_testing.py +++ b/tests/test_web_api_testing.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch from hackingBuddyGPT.usecases import SimpleWebAPITesting +from hackingBuddyGPT.utils.logging import Logger from hackingBuddyGPT.usecases.web_api_testing.simple_web_api_testing import ( SimpleWebAPITestingUseCase, ) @@ -17,16 +18,19 @@ def setUp(self, MockOpenAILib): console = Console() log_db.init() - self.agent = SimpleWebAPITesting(llm=self.mock_llm) - self.agent.init() - self.simple_api_testing = SimpleWebAPITestingUseCase( - agent=self.agent, + log = Logger( log_db=log_db, console=console, tag="integration_test_linuxprivesc", + ) + self.agent = SimpleWebAPITesting(llm=self.mock_llm, log=log) + self.agent.init() + self.simple_api_testing = SimpleWebAPITestingUseCase( + agent=self.agent, + log=log, max_turns=len(self.mock_llm.responses), ) - self.simple_api_testing.init() + self.simple_api_testing.init({}) def test_initial_prompt(self): # Test if the initial prompt is set correctly