Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use string enum EntityType #65

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ jobs:
steps:
- uses: actions/checkout@v4

- uses: actions/setup-python@v5
with:
python-version: '3.12'

- name: Install python libraries
shell: bash
run: |
Expand Down
4 changes: 2 additions & 2 deletions reccmp/isledecomp/compare/asm/replacement.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from functools import cache
from typing import Callable, Protocol
from reccmp.isledecomp.compare.db import ReccmpEntity
from reccmp.isledecomp.types import SymbolType
from reccmp.isledecomp.types import EntityType


class AddrTestProtocol(Protocol):
Expand Down Expand Up @@ -29,7 +29,7 @@ def lookup(addr: int, exact: bool = False) -> str | None:
return m.match_name()

offset = addr - getattr(m, addr_attribute)
if m.compare_type != SymbolType.DATA or offset >= m.size:
if m.entity_type != EntityType.DATA or offset >= m.size:
return None

return m.offset_name(offset)
Expand Down
41 changes: 20 additions & 21 deletions reccmp/isledecomp/compare/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from reccmp.isledecomp.cvdump import Cvdump, CvdumpAnalysis
from reccmp.isledecomp.parser import DecompCodebase
from reccmp.isledecomp.dir import walk_source_dir
from reccmp.isledecomp.types import SymbolType
from reccmp.isledecomp.types import EntityType
from reccmp.isledecomp.compare.asm import ParseAsm
from reccmp.isledecomp.compare.asm.replacement import create_name_lookup
from reccmp.isledecomp.compare.asm.fixes import assert_fixup, find_effective_match
Expand All @@ -27,7 +27,7 @@
@dataclass
class DiffReport:
# pylint: disable=too-many-instance-attributes
match_type: SymbolType
match_type: EntityType
orig_addr: int
recomp_addr: int
name: str
Expand Down Expand Up @@ -165,7 +165,7 @@ def _load_cvdump(self):
- sym.offset
)

if sym.node_type == SymbolType.STRING:
if sym.node_type == EntityType.STRING:
assert sym.decorated_name is not None
string_info = demangle_string_const(sym.decorated_name)
if string_info is None:
Expand Down Expand Up @@ -319,7 +319,7 @@ def _add_match_in_array(
# Indexed by recomp addr. Need to preload this data because it is not stored alongside the db rows.
cvdump_lookup = {x.addr: x for x in self.cvdump_analysis.nodes}

for match in self._db.get_matches_by_type(SymbolType.DATA):
for match in self._db.get_matches_by_type(EntityType.DATA):
node = cvdump_lookup.get(match.recomp_addr)
if node is None or node.data_type is None:
continue
Expand Down Expand Up @@ -405,13 +405,13 @@ def is_real_string(s: str) -> bool:
for addr, string in self.orig_bin.iter_string("latin1"):
if is_real_string(string):
self._db.set_orig_symbol(
addr, type=SymbolType.STRING, name=string, size=len(string)
addr, type=EntityType.STRING, name=string, size=len(string)
)

for addr, string in self.recomp_bin.iter_string("latin1"):
if is_real_string(string):
self._db.set_recomp_symbol(
addr, type=SymbolType.STRING, name=string, size=len(string)
addr, type=EntityType.STRING, name=string, size=len(string)
)

def _find_float_const(self):
Expand All @@ -420,12 +420,12 @@ def _find_float_const(self):
deduped like strings."""
for addr, size, float_value in self.orig_bin.find_float_consts():
self._db.set_orig_symbol(
addr, type=SymbolType.FLOAT, name=str(float_value), size=size
addr, type=EntityType.FLOAT, name=str(float_value), size=size
)

for addr, size, float_value in self.recomp_bin.find_float_consts():
self._db.set_recomp_symbol(
addr, type=SymbolType.FLOAT, name=str(float_value), size=size
addr, type=EntityType.FLOAT, name=str(float_value), size=size
)

def _match_imports(self):
Expand Down Expand Up @@ -453,7 +453,7 @@ def _match_imports(self):
continue

# Match the __imp__ symbol
self._db.set_pair(orig, recomp, SymbolType.POINTER)
self._db.set_pair(orig, recomp, EntityType.POINTER)

# Read the relative address from .idata
try:
Expand All @@ -475,9 +475,9 @@ def _match_imports(self):
(dll_name, func_name) = orig_byaddr[orig]
fullname = dll_name + ":" + func_name
self._db.set_recomp_symbol(
recomp_rva, type=SymbolType.FUNCTION, name=fullname, size=4
recomp_rva, type=EntityType.FUNCTION, name=fullname, size=4
)
self._db.set_pair(orig_rva, recomp_rva, SymbolType.FUNCTION)
self._db.set_pair(orig_rva, recomp_rva, EntityType.FUNCTION)
self._db.skip_compare(orig_rva)

def _match_thunks(self):
Expand Down Expand Up @@ -580,7 +580,7 @@ def _find_vtordisp(self):
We could do this differently and check only the original vtable,
construct the name of the vtordisp function and match based on that."""

for match in self._db.get_matches_by_type(SymbolType.VTABLE):
for match in self._db.get_matches_by_type(EntityType.VTABLE):
assert (
match.name is not None
and match.orig_addr is not None
Expand Down Expand Up @@ -743,7 +743,7 @@ def _compare_function(self, match: ReccmpMatch) -> DiffReport:

assert match.name is not None
return DiffReport(
match_type=SymbolType.FUNCTION,
match_type=EntityType.FUNCTION,
orig_addr=match.orig_addr,
recomp_addr=match.recomp_addr,
name=match.name,
Expand Down Expand Up @@ -831,7 +831,7 @@ def match_text(m: ReccmpEntity | None, raw_addr: int | None = None) -> str:

assert match.name is not None
return DiffReport(
match_type=SymbolType.VTABLE,
match_type=EntityType.VTABLE,
orig_addr=match.orig_addr,
recomp_addr=match.recomp_addr,
name=match.name,
Expand All @@ -848,21 +848,20 @@ def _compare_match(self, match: ReccmpMatch) -> DiffReport | None:
if match.get("skip", False):
return None

assert match.compare_type is not None
assert match.name is not None
if match.get("stub", False):
return DiffReport(
match_type=SymbolType(match.compare_type),
match_type=match.entity_type,
orig_addr=match.orig_addr,
recomp_addr=match.recomp_addr,
name=match.name,
is_stub=True,
)

if match.compare_type == SymbolType.FUNCTION:
if match.entity_type == EntityType.FUNCTION:
return self._compare_function(match)

if match.compare_type == SymbolType.VTABLE:
if match.entity_type == EntityType.VTABLE:
return self._compare_vtable(match)

return None
Expand Down Expand Up @@ -892,13 +891,13 @@ def get_all(self) -> Iterator[ReccmpEntity]:
return self._db.get_all()

def get_functions(self) -> Iterator[ReccmpMatch]:
return self._db.get_matches_by_type(SymbolType.FUNCTION)
return self._db.get_matches_by_type(EntityType.FUNCTION)

def get_vtables(self) -> Iterator[ReccmpMatch]:
return self._db.get_matches_by_type(SymbolType.VTABLE)
return self._db.get_matches_by_type(EntityType.VTABLE)

def get_variables(self) -> Iterator[ReccmpMatch]:
return self._db.get_matches_by_type(SymbolType.DATA)
return self._db.get_matches_by_type(EntityType.DATA)

def compare_address(self, addr: int) -> DiffReport | None:
match = self._db.get_one_match(addr)
Expand Down
Loading
Loading