Skip to content

Commit

Permalink
feature: add support for loading parent's parent (and so on) and load…
Browse files Browse the repository at this point in the history
…ing import from multiple files
  • Loading branch information
cedric05 committed Oct 18, 2023
1 parent 67fab9a commit e829f10
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 31 deletions.
92 changes: 64 additions & 28 deletions dothttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ def __init__(self, args: Config):
def load(self):
self.load_content()
self.load_model()
self.load_imports()
self.load_properties_n_headers()
self.load_command_line_props()
self.validate_names()
Expand Down Expand Up @@ -455,6 +456,22 @@ def load_model(self):
raise HttpFileException(message=e.args)
self.model: MultidefHttp = model


def load_imports(self):
for filename_string in self.model.import_list.filename:
import_file = filename_string.value
if not os.path.exists(import_file):
raise HttpFileNotFoundException(file=import_file)
with open(import_file, 'r', encoding="utf-8") as f:
imported_content = "\n\n#just spacing for easy understanding \n" + f.read()
try:
imported_model = dothttp_model.model_from_str(imported_content)
self.model.allhttps += imported_model.allhttps
except TextXSyntaxError as e:
raise HttpFileSyntaxException(file=self.file, message=e.args)
except Exception as e:
raise HttpFileException(message=e.args)

def load_content(self):
if not os.path.exists(self.file):
raise HttpFileNotFoundException(file=self.file)
Expand All @@ -472,16 +489,23 @@ def select_target(self):
self.http = self.get_target(target, self.model.allhttps)
else:
self.http = self.model.allhttps[0]
self.base_http = None
self.parents_http = []
if self.http.namewrap and self.http.namewrap.base:
base = self.http.namewrap.base
if base == self.http.namewrap.name:
raise ParameterException(message="target and base should not be equal", key=target,
value=base)
parent = self.http.namewrap.base
try:
self.base_http = self.get_target(base, self.model.allhttps)
if parent == self.http.namewrap.name:
raise ParameterException(message="target and base should not be equal", key=target,
value=parent)
while parent:
if parent in self.parents_http:
raise ParameterException(message="Found circular reference", target=self.http.namewrap.name)
grand_http = self.get_target(parent, self.model.allhttps)
self.parents_http.append(grand_http)
parent = grand_http.namewrap.base
except Exception:
raise UndefinedHttpToExtend(target=self.http.namewrap.name, base=base)
raise UndefinedHttpToExtend(target=self.http.namewrap.name, base=parent)
for i in self.parents_http:
print(i.namewrap.name)

@staticmethod
def get_target(target: Union[str, int], http_def_list: List[Http]):
Expand Down Expand Up @@ -525,6 +549,10 @@ def __init__(self, args: Config):

def load_query(self):
params: DefaultDict[List] = defaultdict(list)
for parent in self.parents_http:
for line in parent.lines:
if query := line.query:
params[self.get_updated_content(query.key)].append(self.get_updated_content(query.value))
for line in self.http.lines:
if query := line.query:
params[self.get_updated_content(query.key)].append(self.get_updated_content(query.value))
Expand All @@ -550,7 +578,8 @@ def load_headers(self):
## having duplicate headers creates problem while exporting to curl,postman import..
headers = CaseInsensitiveDict()
headers.update(self.default_headers)
self.load_headers_to_dict(self.base_http, headers)
for parent in self.parents_http:
self.load_headers_to_dict(parent, headers)
self.load_headers_to_dict(self.http, headers)
request_logger.debug(
f'computed query params from `{self.file}` are `{headers}`')
Expand Down Expand Up @@ -581,9 +610,11 @@ def load_extra_flags(self):
# flags are extendable
# once its marked as allow insecure
# user would want all child to have same effect
extra_args = self.http.extra_args
if self.base_http and self.base_http.extra_args:
extra_args += self.base_http.extra_args
extra_args = list(self.http.extra_args)
if self.parents_http:
for parent in self.parents_http:
if parent.extra_args:
extra_args += parent.extra_args
if extra_args:
for flag in extra_args:
if flag.clear:
Expand All @@ -595,21 +626,24 @@ def load_url(self):
request_logger.debug(
f'url is {self.http.urlwrap.url}')
url_path = self.get_updated_content(self.http.urlwrap.url)
if base_http := self.base_http:
base_url = self.get_updated_content(base_http.urlwrap.url)
if not url_path:
self.httpdef.url = base_url
elif url_path.startswith("http://") or url_path.startswith("https://") or url_path.startswith(
"http+unix://"):
self.httpdef.url = url_path
elif base_url.endswith("/") and url_path.startswith("/"):
self.httpdef.url = urljoin(base_url, url_path[1:])
elif url_path.startswith("/"):
self.httpdef.url = urljoin(base_url + "/", url_path[1:])
elif not base_url.endswith("/") and not url_path.startswith("/"):
self.httpdef.url = urljoin(base_url + "/", url_path)
else:
self.httpdef.url = urljoin(base_url, url_path)
if self.parents_http:
for base_http in self.parents_http:
base_url = self.get_updated_content(base_http.urlwrap.url)
if not url_path:
url = base_url
elif url_path.startswith("http://") or url_path.startswith("https://") or url_path.startswith(
"http+unix://"):
url = url_path
elif base_url.endswith("/") and url_path.startswith("/"):
url = urljoin(base_url, url_path[1:])
elif url_path.startswith("/"):
url = urljoin(base_url + "/", url_path[1:])
elif not base_url.endswith("/") and not url_path.startswith("/"):
url = urljoin(base_url + "/", url_path)
else:
url = urljoin(base_url, url_path)
url_path = url
self.httpdef.url = url_path
else:
self.httpdef.url = url_path
if self.httpdef.url and not (
Expand Down Expand Up @@ -859,8 +893,10 @@ def load_auth(self):
def get_current_or_base(self, attr_key) -> Any:
if getattr(self.http, attr_key):
return getattr(self.http, attr_key)
elif self.base_http:
return getattr(self.base_http, attr_key)
elif self.parents_http:
for parent in self.parents_http:
if getattr(parent, attr_key):
return getattr(parent, attr_key)

def load_def(self):
if self._loaded:
Expand Down
6 changes: 3 additions & 3 deletions dothttp/http.tx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//HTTP: ram=HTTP2
MULTISET: IMPORT allhttps=HTTP+;
MULTISET: (import_list=IMPORT)? allhttps=HTTP+;

IMPORT: ('import' String ';')* ;
IMPORT: ('import' filename=String ';')* ;

HTTP:
(
Expand Down Expand Up @@ -232,4 +232,4 @@ SCRIPT_LANGUAGE:

SLASH:
'\'
;
;

0 comments on commit e829f10

Please sign in to comment.