diff --git a/lib/associations.js b/lib/associations.js index beaf2108..5363a9cf 100644 --- a/lib/associations.js +++ b/lib/associations.js @@ -13,10 +13,17 @@ const findAssociatedClass = (association, sourceUmlClass, umlClasses, searchedAb // If a link was found if (umlClass) return umlClass; - // Could not find a link so now need to recursively look at imports of imports + // Could not find association so now need to recursively look at imports of imports // add to already recursively processed files to avoid getting stuck in circular imports searchedAbsolutePaths.push(sourceUmlClass.absolutePath); - return findChainedImport(association, sourceUmlClass, umlClasses, searchedAbsolutePaths); + const importedType = findChainedImport(association, sourceUmlClass, umlClasses, searchedAbsolutePaths); + if (importedType) + return importedType; + // Still could not find association so now need to recursively look for inherited types + const inheritedType = findInheritedType(association, sourceUmlClass, umlClasses); + if (inheritedType) + return inheritedType; + return undefined; }; exports.findAssociatedClass = findAssociatedClass; // Tests if source class can be linked to the target class via an association @@ -36,19 +43,18 @@ const isAssociated = (association, sourceUmlClass, targetUmlClass, targetParentU sourceUmlClass.imports.some((importLink) => importLink.absolutePath === targetUmlClass.absolutePath && importLink.classNames.some((importedClass) => // If a parent contract with no import alias - (association.parentUmlClassName !== undefined && + (association.targetUmlClassName === + targetUmlClass.name && association.parentUmlClassName === importedClass.className && - importedClass.className === - targetUmlClass.name && importedClass.alias == undefined) || // If a parent contract with import alias - (association.parentUmlClassName !== undefined && + (association.targetUmlClassName === + targetUmlClass.name && association.parentUmlClassName === - importedClass.alias && - importedClass.className === - targetUmlClass.name)))); + importedClass.alias)))); } + // No parent class in the association return ( // class is in the same source file (association.targetUmlClassName === targetUmlClass.name && @@ -70,6 +76,39 @@ const isAssociated = (association, sourceUmlClass, targetUmlClass, targetParentU importedClass.alias && importedClass.className === targetUmlClass.name)))); }; +const findInheritedType = (association, sourceUmlClass, umlClasses) => { + // Get all realized associations. + const parentAssociations = sourceUmlClass.getParentContracts(); + // For each parent association + for (const parentAssociation of parentAssociations) { + const parent = (0, exports.findAssociatedClass)(parentAssociation, sourceUmlClass, umlClasses); + if (!parent) + continue; + // For each struct on the parent + for (const structId of parent.structs) { + const structUmlClass = umlClasses.find((c) => c.id === structId); + if (!structUmlClass) + continue; + if (structUmlClass.name === association.targetUmlClassName) { + return structUmlClass; + } + } + // For each enum on the parent + for (const enumId of parent.enums) { + const enumUmlClass = umlClasses.find((c) => c.id === enumId); + if (!enumUmlClass) + continue; + if (enumUmlClass.name === association.targetUmlClassName) { + return enumUmlClass; + } + } + // Recursively look for inherited types + const targetClass = findInheritedType(association, parent, umlClasses); + if (targetClass) + return targetClass; + } + return undefined; +}; const findChainedImport = (association, sourceUmlClass, umlClasses, searchedRelativePaths) => { // Get all valid imports. That is, imports that do not explicitly import contracts or interfaces // or explicitly import the source class diff --git a/src/contracts/Associations.sol b/src/contracts/Associations.sol index a57f3e4e..4146484d 100644 --- a/src/contracts/Associations.sol +++ b/src/contracts/Associations.sol @@ -2,7 +2,7 @@ pragma solidity ^0.8.6; import './libraries/BigInt.sol'; import './libraries/Set.sol'; -import {ImportedFileLevelStruct, ImportedFileLevelStructAliased as IFLSA, ImportedTypesInContract} from './ImportedTypes.sol'; +import {ImportedFileLevelStruct, ImportedFileLevelStructAliased as IFLSA, ImportedTypesInContract, ImportedInterfaceWithStruct, ImportedParentContract, ImportedTypesAliasedContract as ITAS} from './ImportedTypes.sol'; interface ConstructorParamInterface { function someFunction() external returns (bool); @@ -196,6 +196,13 @@ library LibraryWithEnumLinked { } } +interface InterfaceWithStructLinked { + struct InterfaceStruct { + address tester; + uint256 counter; + } +} + struct StructOfStruct { address owner; bool flag; @@ -238,7 +245,11 @@ uint256 constant FileConstant = 5; abstract contract Associations is ContractInterface, ContractAbstract, - ContractConcrete + ContractConcrete, + InterfaceWithStructLinked, + ImportedInterfaceWithStruct, + ImportedParentContract, + ITAS { uint256 public someInt; @@ -251,6 +262,11 @@ abstract contract Associations is ImportedFileLevelStruct importedFileLevelStruct; IFLSA importedFileLevelStructAliased; ImportedTypesInContract.ImportedContractLevelStruct importedTypesInContract; + InterfaceStruct interfaceStruct; + ImportedInterfaceStruct importedInterfaceStruct; + GrandStruct grandStruct; + AliasedStruct aliasedStruct; + AliasedEnum aliasedEnum; FileLevelStorageEnum fileLevelEnum; LibraryWithStructLinked.LibStruct libStruct; diff --git a/src/contracts/ImportedTypes.sol b/src/contracts/ImportedTypes.sol index bae7f5ce..f344a246 100644 --- a/src/contracts/ImportedTypes.sol +++ b/src/contracts/ImportedTypes.sol @@ -22,3 +22,43 @@ contract ImportedTypesInContract { uint64 timestamp; } } + +interface ImportedInterfaceWithStruct { + struct ImportedInterfaceStruct { + address racer; + uint256 points; + } +} + +contract ImportedTypesInGrandContract { + struct GrandStruct { + uint256 total; + address user; + } + enum GrantEnum { + GRANT, + REVOKE + } +} + +contract ImportedParentContract is ImportedTypesInGrandContract { + struct ParentStruct { + bool flag; + uint256 counter; + } + enum ParentEnum { + ONE, + TWO + } +} + +contract ImportedTypesAliasedContract { + struct AliasedStruct { + string name; + string symbol; + } + enum AliasedEnum { + short, + long + } +} diff --git a/src/ts/associations.ts b/src/ts/associations.ts index 007f4767..87a0be78 100644 --- a/src/ts/associations.ts +++ b/src/ts/associations.ts @@ -24,15 +24,26 @@ export const findAssociatedClass = ( // If a link was found if (umlClass) return umlClass - // Could not find a link so now need to recursively look at imports of imports + // Could not find association so now need to recursively look at imports of imports // add to already recursively processed files to avoid getting stuck in circular imports searchedAbsolutePaths.push(sourceUmlClass.absolutePath) - return findChainedImport( + const importedType = findChainedImport( association, sourceUmlClass, umlClasses, searchedAbsolutePaths, ) + if (importedType) return importedType + + // Still could not find association so now need to recursively look for inherited types + const inheritedType = findInheritedType( + association, + sourceUmlClass, + umlClasses, + ) + if (inheritedType) return inheritedType + + return undefined } // Tests if source class can be linked to the target class via an association @@ -63,22 +74,21 @@ const isAssociated = ( importLink.classNames.some( (importedClass) => // If a parent contract with no import alias - (association.parentUmlClassName !== undefined && + (association.targetUmlClassName === + targetUmlClass.name && association.parentUmlClassName === importedClass.className && - importedClass.className === - targetUmlClass.name && importedClass.alias == undefined) || // If a parent contract with import alias - (association.parentUmlClassName !== undefined && + (association.targetUmlClassName === + targetUmlClass.name && association.parentUmlClassName === - importedClass.alias && - importedClass.className === - targetUmlClass.name), + importedClass.alias), ), ) ) } + // No parent class in the association return ( // class is in the same source file (association.targetUmlClassName === targetUmlClass.name && @@ -110,6 +120,47 @@ const isAssociated = ( ) } +const findInheritedType = ( + association: Association, + sourceUmlClass: UmlClass, + umlClasses: readonly UmlClass[], +): UmlClass | undefined => { + // Get all realized associations. + const parentAssociations = sourceUmlClass.getParentContracts() + + // For each parent association + for (const parentAssociation of parentAssociations) { + const parent = findAssociatedClass( + parentAssociation, + sourceUmlClass, + umlClasses, + ) + if (!parent) continue + // For each struct on the parent + for (const structId of parent.structs) { + const structUmlClass = umlClasses.find((c) => c.id === structId) + if (!structUmlClass) continue + if (structUmlClass.name === association.targetUmlClassName) { + return structUmlClass + } + } + // For each enum on the parent + for (const enumId of parent.enums) { + const enumUmlClass = umlClasses.find((c) => c.id === enumId) + if (!enumUmlClass) continue + if (enumUmlClass.name === association.targetUmlClassName) { + return enumUmlClass + } + } + + // Recursively look for inherited types + const targetClass = findInheritedType(association, parent, umlClasses) + if (targetClass) return targetClass + } + + return undefined +} + const findChainedImport = ( association: Association, sourceUmlClass: UmlClass,