diff --git a/src/agent/index.ts b/src/agent/index.ts index 0238b0d..2eec42a 100644 --- a/src/agent/index.ts +++ b/src/agent/index.ts @@ -92,6 +92,7 @@ import { create_proposal } from "../tools/squads_multisig/create_proposal"; import { approve_proposal } from "../tools/squads_multisig/approve_proposal"; import { execute_transaction } from "../tools/squads_multisig/execute_proposal"; import { reject_proposal } from "../tools/squads_multisig/reject_proposal"; +import { get_token_balance } from "../tools/get_token_balances"; /** * Main class for interacting with Solana blockchain @@ -163,22 +164,23 @@ export class SolanaAgentKit { return deploy_collection(this, options); } - async getBalance(token_address?: PublicKey): Promise< - | number - | { - sol: number; - tokens: Array<{ - tokenAddress: string; - name: string; - symbol: string; - balance: number; - decimals: number; - }>; - } - > { + async getBalance(token_address?: PublicKey): Promise { return get_balance(this, token_address); } + async getTokenBalances(wallet_address?: PublicKey): Promise<{ + sol: number; + tokens: Array<{ + tokenAddress: string; + name: string; + symbol: string; + balance: number; + decimals: number; + }>; + }> { + return get_token_balance(this, wallet_address); + } + async getBalanceOther( walletAddress: PublicKey, tokenAddress?: PublicKey, diff --git a/src/tools/get_token_balances.ts b/src/tools/get_token_balances.ts index 00e6310..01f2f2d 100644 --- a/src/tools/get_token_balances.ts +++ b/src/tools/get_token_balances.ts @@ -11,55 +11,49 @@ import { getTokenMetadata } from "../utils/tokenMetadata"; */ export async function get_token_balance( agent: SolanaAgentKit, - token_address?: PublicKey, -): Promise< - | number - | { - sol: number; - tokens: Array<{ - tokenAddress: string; - name: string; - symbol: string; - balance: number; - decimals: number; - }>; - } -> { - if (!token_address) { - const [lamportsBalance, tokenAccountData] = await Promise.all([ - agent.connection.getBalance(agent.wallet_address), - agent.connection.getParsedTokenAccountsByOwner(agent.wallet_address, { + walletAddress?: PublicKey, +): Promise<{ + sol: number; + tokens: Array<{ + tokenAddress: string; + name: string; + symbol: string; + balance: number; + decimals: number; + }>; +}> { + const [lamportsBalance, tokenAccountData] = await Promise.all([ + agent.connection.getBalance(walletAddress ?? agent.wallet_address), + agent.connection.getParsedTokenAccountsByOwner( + walletAddress ?? agent.wallet_address, + { programId: TOKEN_PROGRAM_ID, - }), - ]); + }, + ), + ]); - const removedZeroBalance = tokenAccountData.value.filter( - (v) => v.account.data.parsed.info.tokenAmount.uiAmount !== 0, - ); + const removedZeroBalance = tokenAccountData.value.filter( + (v) => v.account.data.parsed.info.tokenAmount.uiAmount !== 0, + ); - const tokenBalances = await Promise.all( - removedZeroBalance.map(async (v) => { - const mint = v.account.data.parsed.info.mint; - const mintInfo = await getTokenMetadata(agent.connection, mint); - return { - tokenAddress: mint, - name: mintInfo.name ?? "", - symbol: mintInfo.symbol ?? "", - balance: v.account.data.parsed.info.tokenAmount.uiAmount as number, - decimals: v.account.data.parsed.info.tokenAmount.decimals as number, - }; - }), - ); + const tokenBalances = await Promise.all( + removedZeroBalance.map(async (v) => { + const mint = v.account.data.parsed.info.mint; + const mintInfo = await getTokenMetadata(agent.connection, mint); + return { + tokenAddress: mint, + name: mintInfo.name ?? "", + symbol: mintInfo.symbol ?? "", + balance: v.account.data.parsed.info.tokenAmount.uiAmount as number, + decimals: v.account.data.parsed.info.tokenAmount.decimals as number, + }; + }), + ); - const solBalance = lamportsBalance / LAMPORTS_PER_SOL; + const solBalance = lamportsBalance / LAMPORTS_PER_SOL; - return { - sol: solBalance, - tokens: tokenBalances, - }; - } - - const token_account = - await agent.connection.getTokenAccountBalance(token_address); - return token_account.value.uiAmount || 0; + return { + sol: solBalance, + tokens: tokenBalances, + }; }