Skip to content

Commit

Permalink
refactor: override NTTManager _transferEntryPoint function
Browse files Browse the repository at this point in the history
  • Loading branch information
0xIryna committed Jan 22, 2025
1 parent ba0e968 commit 5cb74cd
Show file tree
Hide file tree
Showing 12 changed files with 137 additions and 48 deletions.
13 changes: 13 additions & 0 deletions .github/example-native-token-transfers.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
diff --git a/evm/src/NttManager/NttManager.sol b/evm/src/NttManager/NttManager.sol
index 7ac021b..60e4aee 100644
--- a/evm/src/NttManager/NttManager.sol
+++ b/evm/src/NttManager/NttManager.sol
@@ -383,7 +383,7 @@ contract NttManager is INttManager, RateLimiter, ManagerBase {
bytes32 refundAddress,
bool shouldQueue,
bytes memory transceiverInstructions
- ) internal returns (uint64) {
+ ) internal virtual returns (uint64) {
if (amount == 0) {
revert ZeroAmount();
}
5 changes: 5 additions & 0 deletions .github/workflows/test-fork.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ jobs:
with:
submodules: recursive

- name: Apply patch
run: |
cd lib/example-native-token-transfers
git apply ../../.github/example-native-token-transfers.diff
- name: Setup Node
uses: actions/setup-node@v4

Expand Down
5 changes: 5 additions & 0 deletions .github/workflows/test-gas.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ jobs:
with:
submodules: recursive

- name: Apply patch
run: |
cd lib/example-native-token-transfers
git apply ../../.github/example-native-token-transfers.diff
- name: Setup Node
uses: actions/setup-node@v4

Expand Down
5 changes: 5 additions & 0 deletions .github/workflows/test-sizes.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ jobs:
with:
submodules: recursive

- name: Apply patch
run: |
cd lib/example-native-token-transfers
git apply ../../.github/example-native-token-transfers.diff
- name: Install Foundry
uses: foundry-rs/foundry-toolchain@v1

Expand Down
2 changes: 1 addition & 1 deletion foundry.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ via_ir = false
[profile.production]
build_info = true
optimizer = true
optimizer_runs = 800
optimizer_runs = 1500
sizes = true
via_ir = true

Expand Down
118 changes: 75 additions & 43 deletions src/Portal.sol
Original file line number Diff line number Diff line change
Expand Up @@ -105,28 +105,85 @@ abstract contract Portal is NttManagerNoRateLimiting, IPortal {
bytes32 recipient_,
bytes32 refundAddress_
) external payable returns (bytes32 messageId_) {
if (amount_ == 0) revert ZeroAmount();
_verifyTransferArgs(amount_, destinationToken_, recipient_, refundAddress_);

if (sourceToken_ == address(0)) revert ZeroSourceToken();
if (destinationToken_ == bytes32(0)) revert ZeroDestinationToken();
if (recipient_ == bytes32(0)) revert InvalidRecipient();
if (refundAddress_ == bytes32(0)) revert InvalidRefundAddress();
if (!supportedDestinationToken[destinationChainId_][destinationToken_])
revert UnsupportedDestinationToken(destinationChainId_, destinationToken_);

IERC20 mToken_ = IERC20(token);
uint256 balanceBefore = mToken_.balanceOf(address(this));

// transfer source token from the sender
IERC20(sourceToken_).transferFrom(msg.sender, address(this), amount_);

// if the source token isn't M token, unwrap it
if (sourceToken_ != token) {
if (sourceToken_ != address(mToken_)) {
amount_ = IWrappedMTokenLike(sourceToken_).unwrap(address(this), amount_);
}

// account for potential rounding errors when transferring between earners and non-earners
amount_ = mToken_.balanceOf(address(this)) - balanceBefore;

(messageId_, ) = _transferMToken(
amount_,
sourceToken_,
destinationToken_,
destinationChainId_,
recipient_,
refundAddress_
);
}

/* ============ Internal/Private Interactive Functions ============ */

/// @dev Called from NTTManager `transfer` function to transfer M token
/// Overridden to reduce code duplication, optimize gas cost and prevent Yul stack too deep
function _transferEntryPoint(
uint256 amount_,
uint16 destinationChainId_,
bytes32 recipient_,
bytes32 refundAddress_,
bool, // shouldQueue_
bytes memory // transceiverInstructions_
) internal override returns (uint64 sequence_) {
bytes32 destinationToken_ = destinationMToken[destinationChainId_];

_verifyTransferArgs(amount_, destinationToken_, recipient_, refundAddress_);

IERC20 mToken_ = IERC20(token);
uint256 balanceBefore = mToken_.balanceOf(address(this));

// transfer M token from the sender
mToken_.transferFrom(msg.sender, address(this), amount_);

// account for potential rounding errors when transferring between earners and non-earners
amount_ = mToken_.balanceOf(address(this)) - balanceBefore;

(, sequence_) = _transferMToken(
amount_,
token,
destinationToken_,
destinationChainId_,
recipient_,
refundAddress_
);
}

function _transferMToken(
uint256 amount_,
address sourceToken_,
bytes32 destinationToken_,
uint16 destinationChainId_,
bytes32 recipient_,
bytes32 refundAddress_
) private returns (bytes32 messageId_, uint64 sequence_) {
// NOTE: the following code has been adapted from NTT manager `transfer` or `_transferEntryPoint` functions.
// We cannot call those functions directly here as they attempt to transfer M Token from the msg.sender.

_burnOrLock(amount_);

uint64 sequence_ = _useMessageSequence();
sequence_ = _useMessageSequence();
uint128 index_ = _currentIndex();

TransceiverStructs.NttManagerMessage memory message_;
Expand Down Expand Up @@ -167,43 +224,6 @@ abstract contract Portal is NttManagerNoRateLimiting, IPortal {
);
emit TransferSent(messageId_);
}
/* ============ Internal/Private Interactive Functions ============ */

/// @dev Called from NTT manager during M Token transfer to customize additional payload.
/// Adds M Token index and empty Wrapper Address to the NTT payload.
function _prepareNativeTokenTransfer(
TrimmedAmount amount_,
bytes32 recipient_,
uint16 destinationChainId_,
uint64 sequence_,
address sender_,
bytes32 // refundAddress
) internal override returns (TransceiverStructs.NativeTokenTransfer memory nativeTokenTransfer_) {
uint128 index_ = _currentIndex();
bytes32 destinationMToken_ = destinationMToken[destinationChainId_];
bytes32 messageId_;
(nativeTokenTransfer_, , messageId_) = _encodeTokenTransfer(
amount_,
index_,
recipient_,
destinationMToken_,
destinationChainId_,
sequence_,
sender_
);

uint256 untrimmedAmount_ = amount_.untrim(tokenDecimals());
emit MTokenSent(
destinationChainId_,
token,
destinationMToken_,
messageId_,
sender_,
recipient_,
untrimmedAmount_,
index_
);
}

function _encodeTokenTransfer(
TrimmedAmount amount_,
Expand Down Expand Up @@ -349,6 +369,18 @@ abstract contract Portal is NttManagerNoRateLimiting, IPortal {
if (evmChainId_ != block.chainid) revert InvalidFork(evmChainId_, block.chainid);
}

function _verifyTransferArgs(
uint256 amount_,
bytes32 destinationToken_,
bytes32 recipient_,
bytes32 refundAddress_
) private view {
if (amount_ == 0) revert ZeroAmount();
if (destinationToken_ == bytes32(0)) revert ZeroDestinationToken();
if (recipient_ == bytes32(0)) revert InvalidRecipient();
if (refundAddress_ == bytes32(0)) revert InvalidRefundAddress();
}

/**
* @dev HubPortal: unlocks and transfers `amount_` M tokens to `recipient_`.
* SpokePortal: mints `amount_` M tokens to `recipient_`.
Expand Down
6 changes: 6 additions & 0 deletions test/fork/HubPortalFork.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ import { IMToken } from "../../lib/protocol/src/interfaces/IMToken.sol";
import { IHubPortal } from "../../src/interfaces/IHubPortal.sol";
import { IPortal } from "../../src/interfaces/IPortal.sol";
import { IRegistrarLike } from "../../src/interfaces/IRegistrarLike.sol";
import { TypeConverter } from "../../src/libs/TypeConverter.sol";

import { ForkTestBase } from "./ForkTestBase.t.sol";

contract HubPortalForkTests is ForkTestBase {
using TypeConverter for *;

function setUp() public override {
super.setUp();
_configurePortals();
Expand All @@ -31,6 +34,9 @@ contract HubPortalForkTests is ForkTestBase {

uint128 mainnetIndex_ = IContinuousIndexing(_MAINNET_M_TOKEN).currentIndex();

vm.prank(_DEPLOYER);
IPortal(_hubPortal).setDestinationMToken(_BASE_WORMHOLE_CHAIN_ID, _baseSpokeMToken.toBytes32());

vm.startPrank(_mHolder);
vm.recordLogs();

Expand Down
12 changes: 12 additions & 0 deletions test/fork/SpokePortalFork.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ pragma solidity 0.8.26;
import { IERC20 } from "../../lib/common/src/interfaces/IERC20.sol";
import { IContinuousIndexing } from "../../lib/protocol/src/interfaces/IContinuousIndexing.sol";

import { IPortal } from "../../src/interfaces/IPortal.sol";
import { ISpokePortal } from "../../src/interfaces/ISpokePortal.sol";
import { TypeConverter } from "../../src/libs/TypeConverter.sol";

import { ForkTestBase } from "./ForkTestBase.t.sol";

contract SpokePortalForkTests is ForkTestBase {
using TypeConverter for *;
uint256 internal _amount;
uint128 internal _mainnetIndex;

Expand All @@ -23,6 +26,9 @@ contract SpokePortalForkTests is ForkTestBase {
function testFork_transferToHubPortal() external {
_beforeTest();

vm.prank(_DEPLOYER);
IPortal(_baseSpokePortal).setDestinationMToken(_MAINNET_WORMHOLE_CHAIN_ID, _MAINNET_M_TOKEN.toBytes32());

vm.startPrank(_mHolder);

IERC20(_baseSpokeMToken).approve(_baseSpokePortal, _amount);
Expand Down Expand Up @@ -56,6 +62,9 @@ contract SpokePortalForkTests is ForkTestBase {
function testFork_transferBetweenSpokePortals() external {
_beforeTest();

vm.prank(_DEPLOYER);
IPortal(_baseSpokePortal).setDestinationMToken(_OPTIMISM_WORMHOLE_CHAIN_ID, _optimismSpokeMToken.toBytes32());

vm.startPrank(_mHolder);

IERC20(_baseSpokeMToken).approve(_baseSpokePortal, _amount);
Expand Down Expand Up @@ -90,6 +99,9 @@ contract SpokePortalForkTests is ForkTestBase {
// First, transfer M tokens to the Spoke chain.
vm.selectFork(_mainnetForkId);

vm.prank(_DEPLOYER);
IPortal(_hubPortal).setDestinationMToken(_BASE_WORMHOLE_CHAIN_ID, _baseSpokeMToken.toBytes32());

_mainnetIndex = IContinuousIndexing(_MAINNET_M_TOKEN).currentIndex();

vm.startPrank(_mHolder);
Expand Down
11 changes: 11 additions & 0 deletions test/fork/SpokeVaultFork.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@ pragma solidity 0.8.26;

import { IERC20 } from "../../lib/common/src/interfaces/IERC20.sol";

import { IPortal } from "../../src/interfaces/IPortal.sol";
import { TypeConverter } from "../../src/libs/TypeConverter.sol";

import { ForkTestBase } from "./ForkTestBase.t.sol";

contract SpokeVaultForkTests is ForkTestBase {
using TypeConverter for *;

uint256 internal _amount;

function setUp() public override {
Expand All @@ -19,6 +24,9 @@ contract SpokeVaultForkTests is ForkTestBase {
function testFork_transferExcessM() external {
_beforeTest();

vm.prank(_DEPLOYER);
IPortal(_baseSpokePortal).setDestinationMToken(_MAINNET_WORMHOLE_CHAIN_ID, _MAINNET_M_TOKEN.toBytes32());

vm.startPrank(_mHolder);

// Then, transfer excess M tokens to the Hub chain.
Expand Down Expand Up @@ -49,6 +57,9 @@ contract SpokeVaultForkTests is ForkTestBase {

vm.selectFork(_mainnetForkId);

vm.prank(_DEPLOYER);
IPortal(_hubPortal).setDestinationMToken(_BASE_WORMHOLE_CHAIN_ID, _MAINNET_M_TOKEN.toBytes32());

vm.startPrank(_mHolder);

// First, transfer M tokens to the Spoke chain
Expand Down
1 change: 1 addition & 0 deletions test/unit/HubPortal.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ contract HubPortalTests is UnitTestBase {
_portal = HubPortal(_createProxy(address(implementation_)));

_initializePortal(_portal);
_portal.setDestinationMToken(_REMOTE_CHAIN_ID, _remoteMToken);
}

/* ============ initialState ============ */
Expand Down
5 changes: 1 addition & 4 deletions test/unit/Portal.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ contract PortalTests is UnitTestBase {
);
_portal = PortalHarness(_createProxy(address(implementation_)));
_initializePortal(_portal);
_portal.setDestinationMToken(_REMOTE_CHAIN_ID, _remoteMToken);
}

/* ============ constructor ============ */
Expand All @@ -68,15 +69,11 @@ contract PortalTests is UnitTestBase {

function test_transfer_zeroAmount() external {
vm.expectRevert(INttManager.ZeroAmount.selector);

vm.prank(_alice);
_portal.transfer(0, _REMOTE_CHAIN_ID, _alice.toBytes32());
}

function test_transfer_zeroRecipient() external {
vm.expectRevert(INttManager.InvalidRecipient.selector);

vm.prank(_alice);
_portal.transfer(1_000e6, _REMOTE_CHAIN_ID, bytes32(0));
}

Expand Down
2 changes: 2 additions & 0 deletions test/unit/SpokePortal.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ contract SpokePortalTests is UnitTestBase {
_portal = SpokePortal(_createProxy(address(implementation_)));

_initializePortal(_portal);

_portal.setDestinationMToken(_REMOTE_CHAIN_ID, _remoteMToken);
}

/* ============ initialState ============ */
Expand Down

0 comments on commit 5cb74cd

Please sign in to comment.