diff --git a/app/components/UI/Stake/components/StakeButton/index.tsx b/app/components/UI/Stake/components/StakeButton/index.tsx index ca578bef3bd..163e88900bc 100644 --- a/app/components/UI/Stake/components/StakeButton/index.tsx +++ b/app/components/UI/Stake/components/StakeButton/index.tsx @@ -25,6 +25,8 @@ import { strings } from '../../../../../../locales/i18n'; import { RootState } from '../../../../../reducers'; import useStakingEligibility from '../../hooks/useStakingEligibility'; import { StakeSDKProvider } from '../../sdk/stakeSdkProvider'; +import useStakingChain from '../../hooks/useStakingChain'; +import Engine from '../../../../../core/Engine'; interface StakeButtonProps { asset: TokenI; @@ -37,12 +39,15 @@ const StakeButtonContent = ({ asset }: StakeButtonProps) => { const browserTabs = useSelector((state: RootState) => state.browser.tabs); const chainId = useSelector(selectChainId); - - const { refreshPooledStakingEligibility } = useStakingEligibility(); + const { isEligible } = useStakingEligibility(); + const { isStakingSupportedChain } = useStakingChain(); const onStakeButtonPress = async () => { - const { isEligible } = await refreshPooledStakingEligibility(); - if (isPooledStakingFeatureEnabled() && isEligible) { + if (!isStakingSupportedChain) { + const { NetworkController } = Engine.context; + await NetworkController.setActiveNetwork('mainnet'); + } + if (isEligible) { navigation.navigate('StakeScreens', { screen: Routes.STAKING.STAKE }); } else { const existingStakeTab = browserTabs.find((tab: BrowserTab) => diff --git a/app/components/UI/Tokens/index.test.tsx b/app/components/UI/Tokens/index.test.tsx index 13cbb2a6f8e..70bc0545537 100644 --- a/app/components/UI/Tokens/index.test.tsx +++ b/app/components/UI/Tokens/index.test.tsx @@ -250,10 +250,14 @@ jest.mock('../../UI/Stake/hooks/useStakingEligibility', () => ({ })), })); -jest.mock('../Stake/hooks/useStakingChain', () => ({ - useStakingChainByChainId: () => ({ +jest.mock('../../UI/Stake/hooks/useStakingChain', () => ({ + __esModule: true, + default: jest.fn(() => ({ isStakingSupportedChain: true, - }), + })), + useStakingChainByChainId: jest.fn(() => ({ + isStakingSupportedChain: true, + })), })); const Stack = createStackNavigator();