From 272f77f851071eef212873dc39c95a5dde6bfd44 Mon Sep 17 00:00:00 2001 From: michaelessiet Date: Fri, 10 Jan 2025 17:31:08 +0100 Subject: [PATCH] feat: token balances --- src/actions/balance.ts | 5 ++- src/actions/index.ts | 2 + src/actions/tokenBalances.ts | 80 +++++++++++++++++++++++++++++++++ src/tools/get_balance.ts | 53 +++------------------- src/tools/get_token_balances.ts | 65 +++++++++++++++++++++++++++ test/agent_sdks/vercel_ai.ts | 1 - 6 files changed, 156 insertions(+), 50 deletions(-) create mode 100644 src/actions/tokenBalances.ts create mode 100644 src/tools/get_token_balances.ts diff --git a/src/actions/balance.ts b/src/actions/balance.ts index 1ee1b56..437e479 100644 --- a/src/actions/balance.ts +++ b/src/actions/balance.ts @@ -1,6 +1,6 @@ import { PublicKey } from "@solana/web3.js"; -import { Action } from "../types/action"; -import { SolanaAgentKit } from "../agent"; +import type { Action } from "../types/action"; +import type { SolanaAgentKit } from "../agent"; import { z } from "zod"; import { get_balance } from "../tools"; @@ -54,6 +54,7 @@ const balanceAction: Action = { return { status: "success", balance: balance, + token: input.tokenAddress || "SOL", }; }, }; diff --git a/src/actions/index.ts b/src/actions/index.ts index c974209..9eec07d 100644 --- a/src/actions/index.ts +++ b/src/actions/index.ts @@ -30,9 +30,11 @@ import launchPumpfunTokenAction from "./launchPumpfunToken"; import getWalletAddressAction from "./getWalletAddress"; import flashOpenTradeAction from "./flashOpenTrade"; import flashCloseTradeAction from "./flashCloseTrade"; +import tokenBalancesAction from "./tokenBalances"; export const ACTIONS = { WALLET_ADDRESS_ACTION: getWalletAddressAction, + TOKEN_BALANCES_ACTION: tokenBalancesAction, DEPLOY_TOKEN_ACTION: deployTokenAction, BALANCE_ACTION: balanceAction, TRANSFER_ACTION: transferAction, diff --git a/src/actions/tokenBalances.ts b/src/actions/tokenBalances.ts new file mode 100644 index 0000000..960fca1 --- /dev/null +++ b/src/actions/tokenBalances.ts @@ -0,0 +1,80 @@ +import { PublicKey } from "@solana/web3.js"; +import type { Action } from "../types/action"; +import type { SolanaAgentKit } from "../agent"; +import { z } from "zod"; +import { get_token_balance } from "../tools/get_token_balances"; + +const tokenBalancesAction: Action = { + name: "TOKEN_BALANCE_ACTION", + similes: [ + "check token balances", + "get wallet token balances", + "view token balances", + "show token balances", + "check token balance", + ], + description: `Get the token balances of a Solana wallet. + If you want to get the balance of your wallet, you don't need to provide the wallet address.`, + examples: [ + [ + { + input: {}, + output: { + status: "success", + balance: { + sol: 100, + tokens: [ + { + tokenAddress: "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v", + name: "USD Coin", + symbol: "USDC", + balance: 100, + decimals: 9, + }, + ], + }, + }, + explanation: "Get token balances of the wallet", + }, + ], + [ + { + input: { + walletAddress: "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v", + }, + output: { + status: "success", + balance: { + sol: 100, + tokens: [ + { + tokenAddress: "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v", + name: "USD Coin", + symbol: "USDC", + balance: 100, + decimals: 9, + }, + ], + }, + }, + explanation: "Get address token balance", + }, + ], + ], + schema: z.object({ + walletAddress: z.string().optional(), + }), + handler: async (agent: SolanaAgentKit, input: Record) => { + const balance = await get_token_balance( + agent, + input.tokenAddress && new PublicKey(input.tokenAddress), + ); + + return { + status: "success", + balance: balance, + }; + }, +}; + +export default tokenBalancesAction; diff --git a/src/tools/get_balance.ts b/src/tools/get_balance.ts index 3e30021..f7a6a1e 100644 --- a/src/tools/get_balance.ts +++ b/src/tools/get_balance.ts @@ -1,7 +1,5 @@ -import { LAMPORTS_PER_SOL, PublicKey } from "@solana/web3.js"; -import { SolanaAgentKit } from "../index"; -import { TOKEN_PROGRAM_ID } from "@solana/spl-token"; -import { getTokenMetadata } from "../utils/tokenMetadata"; +import { LAMPORTS_PER_SOL, type PublicKey } from "@solana/web3.js"; +import type { SolanaAgentKit } from "../index"; /** * Get the balance of SOL or an SPL token for the agent's wallet @@ -12,51 +10,12 @@ import { getTokenMetadata } from "../utils/tokenMetadata"; export async function get_balance( agent: SolanaAgentKit, token_address?: PublicKey, -): Promise< - | number - | { - sol: number; - tokens: Array<{ - tokenAddress: string; - name: string; - symbol: string; - balance: number; - decimals: number; - }>; - } -> { +): Promise { if (!token_address) { - const [lamportsBalance, tokenAccountData] = await Promise.all([ - agent.connection.getBalance(agent.wallet_address), - agent.connection.getParsedTokenAccountsByOwner(agent.wallet_address, { - programId: TOKEN_PROGRAM_ID, - }), - ]); - - const removedZeroBalance = tokenAccountData.value.filter( - (v) => v.account.data.parsed.info.tokenAmount.uiAmount !== 0, + return ( + (await agent.connection.getBalance(agent.wallet_address)) / + LAMPORTS_PER_SOL ); - - const tokenBalances = await Promise.all( - removedZeroBalance.map(async (v) => { - const mint = v.account.data.parsed.info.mint; - const mintInfo = await getTokenMetadata(agent.connection, mint); - return { - tokenAddress: mint, - name: mintInfo.name ?? "", - symbol: mintInfo.symbol ?? "", - balance: v.account.data.parsed.info.tokenAmount.uiAmount as number, - decimals: v.account.data.parsed.info.tokenAmount.decimals as number, - }; - }), - ); - - const solBalance = lamportsBalance / LAMPORTS_PER_SOL; - - return { - sol: solBalance, - tokens: tokenBalances, - }; } const token_account = diff --git a/src/tools/get_token_balances.ts b/src/tools/get_token_balances.ts new file mode 100644 index 0000000..00e6310 --- /dev/null +++ b/src/tools/get_token_balances.ts @@ -0,0 +1,65 @@ +import { LAMPORTS_PER_SOL, type PublicKey } from "@solana/web3.js"; +import type { SolanaAgentKit } from "../index"; +import { TOKEN_PROGRAM_ID } from "@solana/spl-token"; +import { getTokenMetadata } from "../utils/tokenMetadata"; + +/** + * Get the token balances of a Solana wallet + * @param agent - SolanaAgentKit instance + * @param token_address - Optional SPL token mint address. If not provided, returns SOL balance + * @returns Promise resolving to the balance as an object containing sol balance and token balances with their respective mints, symbols, names and decimals + */ +export async function get_token_balance( + agent: SolanaAgentKit, + token_address?: PublicKey, +): Promise< + | number + | { + sol: number; + tokens: Array<{ + tokenAddress: string; + name: string; + symbol: string; + balance: number; + decimals: number; + }>; + } +> { + if (!token_address) { + const [lamportsBalance, tokenAccountData] = await Promise.all([ + agent.connection.getBalance(agent.wallet_address), + agent.connection.getParsedTokenAccountsByOwner(agent.wallet_address, { + programId: TOKEN_PROGRAM_ID, + }), + ]); + + const removedZeroBalance = tokenAccountData.value.filter( + (v) => v.account.data.parsed.info.tokenAmount.uiAmount !== 0, + ); + + const tokenBalances = await Promise.all( + removedZeroBalance.map(async (v) => { + const mint = v.account.data.parsed.info.mint; + const mintInfo = await getTokenMetadata(agent.connection, mint); + return { + tokenAddress: mint, + name: mintInfo.name ?? "", + symbol: mintInfo.symbol ?? "", + balance: v.account.data.parsed.info.tokenAmount.uiAmount as number, + decimals: v.account.data.parsed.info.tokenAmount.decimals as number, + }; + }), + ); + + const solBalance = lamportsBalance / LAMPORTS_PER_SOL; + + return { + sol: solBalance, + tokens: tokenBalances, + }; + } + + const token_account = + await agent.connection.getTokenAccountBalance(token_address); + return token_account.value.uiAmount || 0; +} diff --git a/test/agent_sdks/vercel_ai.ts b/test/agent_sdks/vercel_ai.ts index 77fda22..bf1585e 100644 --- a/test/agent_sdks/vercel_ai.ts +++ b/test/agent_sdks/vercel_ai.ts @@ -95,7 +95,6 @@ async function runChatMode() { ); const tools = createVercelAITools(solanaAgent); - console.log(tools); const rl = readline.createInterface({ input: process.stdin,