diff --git a/src/actions/index.ts b/src/actions/index.ts index 63ec7c5..29cf233 100644 --- a/src/actions/index.ts +++ b/src/actions/index.ts @@ -1,3 +1,4 @@ +import tokenBalancesAction from "./tokenBalances"; import deployTokenAction from "./metaplex/deployToken"; import balanceAction from "./solana/balance"; import transferAction from "./solana/transfer"; @@ -61,6 +62,7 @@ import updateDriftVaultDelegateAction from "./drift/updateDriftVaultDelegate"; export const ACTIONS = { WALLET_ADDRESS_ACTION: getWalletAddressAction, + TOKEN_BALANCES_ACTION: tokenBalancesAction, DEPLOY_TOKEN_ACTION: deployTokenAction, BALANCE_ACTION: balanceAction, TRANSFER_ACTION: transferAction, diff --git a/src/actions/solana/balance.ts b/src/actions/solana/balance.ts index 1009033..b343fef 100644 --- a/src/actions/solana/balance.ts +++ b/src/actions/solana/balance.ts @@ -1,6 +1,6 @@ import { PublicKey } from "@solana/web3.js"; -import { Action } from "../../types/action"; -import { SolanaAgentKit } from "../../agent"; +import type { Action } from "../../types/action"; +import type { SolanaAgentKit } from "../../agent"; import { z } from "zod"; import { get_balance } from "../../tools"; diff --git a/src/actions/tokenBalances.ts b/src/actions/tokenBalances.ts new file mode 100644 index 0000000..f21ccbf --- /dev/null +++ b/src/actions/tokenBalances.ts @@ -0,0 +1,80 @@ +import { PublicKey } from "@solana/web3.js"; +import type { Action } from "../types/action"; +import type { SolanaAgentKit } from "../agent"; +import { z } from "zod"; +import { get_token_balance } from "../tools"; + +const tokenBalancesAction: Action = { + name: "TOKEN_BALANCE_ACTION", + similes: [ + "check token balances", + "get wallet token balances", + "view token balances", + "show token balances", + "check token balance", + ], + description: `Get the token balances of a Solana wallet. + If you want to get the balance of your wallet, you don't need to provide the wallet address.`, + examples: [ + [ + { + input: {}, + output: { + status: "success", + balance: { + sol: 100, + tokens: [ + { + tokenAddress: "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v", + name: "USD Coin", + symbol: "USDC", + balance: 100, + decimals: 9, + }, + ], + }, + }, + explanation: "Get token balances of the wallet", + }, + ], + [ + { + input: { + walletAddress: "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v", + }, + output: { + status: "success", + balance: { + sol: 100, + tokens: [ + { + tokenAddress: "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v", + name: "USD Coin", + symbol: "USDC", + balance: 100, + decimals: 9, + }, + ], + }, + }, + explanation: "Get address token balance", + }, + ], + ], + schema: z.object({ + walletAddress: z.string().optional(), + }), + handler: async (agent: SolanaAgentKit, input) => { + const balance = await get_token_balance( + agent, + input.tokenAddress && new PublicKey(input.tokenAddress), + ); + + return { + status: "success", + balance: balance, + }; + }, +}; + +export default tokenBalancesAction; diff --git a/src/agent/index.ts b/src/agent/index.ts index 45f085f..d087931 100644 --- a/src/agent/index.ts +++ b/src/agent/index.ts @@ -97,6 +97,7 @@ import { withdrawFromDriftUserAccount, withdrawFromDriftVault, updateVaultDelegate, + get_token_balance, } from "../tools"; import { Config, @@ -189,6 +190,19 @@ export class SolanaAgentKit { return get_balance(this, token_address); } + async getTokenBalances(wallet_address?: PublicKey): Promise<{ + sol: number; + tokens: Array<{ + tokenAddress: string; + name: string; + symbol: string; + balance: number; + decimals: number; + }>; + }> { + return get_token_balance(this, wallet_address); + } + async getBalanceOther( walletAddress: PublicKey, tokenAddress?: PublicKey, diff --git a/src/tools/solana/get_token_balances.ts b/src/tools/solana/get_token_balances.ts new file mode 100644 index 0000000..47272db --- /dev/null +++ b/src/tools/solana/get_token_balances.ts @@ -0,0 +1,59 @@ +import { LAMPORTS_PER_SOL, type PublicKey } from "@solana/web3.js"; +import type { SolanaAgentKit } from "../../index"; +import { TOKEN_PROGRAM_ID } from "@solana/spl-token"; +import { getTokenMetadata } from "../../utils/tokenMetadata"; + +/** + * Get the token balances of a Solana wallet + * @param agent - SolanaAgentKit instance + * @param token_address - Optional SPL token mint address. If not provided, returns SOL balance + * @returns Promise resolving to the balance as an object containing sol balance and token balances with their respective mints, symbols, names and decimals + */ +export async function get_token_balance( + agent: SolanaAgentKit, + walletAddress?: PublicKey, +): Promise<{ + sol: number; + tokens: Array<{ + tokenAddress: string; + name: string; + symbol: string; + balance: number; + decimals: number; + }>; +}> { + const [lamportsBalance, tokenAccountData] = await Promise.all([ + agent.connection.getBalance(walletAddress ?? agent.wallet_address), + agent.connection.getParsedTokenAccountsByOwner( + walletAddress ?? agent.wallet_address, + { + programId: TOKEN_PROGRAM_ID, + }, + ), + ]); + + const removedZeroBalance = tokenAccountData.value.filter( + (v) => v.account.data.parsed.info.tokenAmount.uiAmount !== 0, + ); + + const tokenBalances = await Promise.all( + removedZeroBalance.map(async (v) => { + const mint = v.account.data.parsed.info.mint; + const mintInfo = await getTokenMetadata(agent.connection, mint); + return { + tokenAddress: mint, + name: mintInfo.name ?? "", + symbol: mintInfo.symbol ?? "", + balance: v.account.data.parsed.info.tokenAmount.uiAmount as number, + decimals: v.account.data.parsed.info.tokenAmount.decimals as number, + }; + }), + ); + + const solBalance = lamportsBalance / LAMPORTS_PER_SOL; + + return { + sol: solBalance, + tokens: tokenBalances, + }; +} diff --git a/src/tools/solana/index.ts b/src/tools/solana/index.ts index f9681a4..a7b6de9 100644 --- a/src/tools/solana/index.ts +++ b/src/tools/solana/index.ts @@ -4,3 +4,4 @@ export * from "./close_empty_token_accounts"; export * from "./transfer"; export * from "./get_balance"; export * from "./get_balance_other"; +export * from "./get_token_balances"; diff --git a/src/utils/tokenMetadata.ts b/src/utils/tokenMetadata.ts new file mode 100644 index 0000000..514a208 --- /dev/null +++ b/src/utils/tokenMetadata.ts @@ -0,0 +1,83 @@ +import { Connection, PublicKey } from "@solana/web3.js"; + +export async function getTokenMetadata( + connection: Connection, + tokenMint: string, +) { + const METADATA_PROGRAM_ID = new PublicKey( + "metaqbxxUerdq28cj1RbAWkYQm3ybzjb6a8bt518x1s", + ); + + const [metadataPDA] = PublicKey.findProgramAddressSync( + [ + Buffer.from("metadata"), + METADATA_PROGRAM_ID.toBuffer(), + new PublicKey(tokenMint).toBuffer(), + ], + METADATA_PROGRAM_ID, + ); + + const metadata = await connection.getAccountInfo(metadataPDA); + if (!metadata?.data) { + throw new Error("Metadata not found"); + } + + let offset = 1 + 32 + 32; // key + update auth + mint + const data = metadata.data; + const decoder = new TextDecoder(); + + // Read variable length strings + const readString = () => { + let nameLength = data[offset]; + + while (nameLength === 0) { + offset++; + nameLength = data[offset]; + if (offset >= data.length) { + return null; + } + } + + offset++; + const name = decoder + .decode(data.slice(offset, offset + nameLength)) + // @eslint-disable-next-line no-control-regex + .replace(new RegExp(String.fromCharCode(0), "g"), ""); + offset += nameLength; + return name; + }; + + const name = readString(); + const symbol = readString(); + const uri = readString(); + + // Read remaining data + const sellerFeeBasisPoints = data.readUInt16LE(offset); + offset += 2; + + let creators: + | { address: PublicKey; verified: boolean; share: number }[] + | null = null; + if (data[offset] === 1) { + offset++; + const numCreators = data[offset]; + offset++; + creators = [...Array(numCreators)].map(() => { + const creator = { + address: new PublicKey(data.slice(offset, offset + 32)), + verified: data[offset + 32] === 1, + share: data[offset + 33], + }; + offset += 34; + return creator; + }); + } + + return { + name, + symbol, + uri, + sellerFeeBasisPoints, + creators, + }; +} diff --git a/test/agent_sdks/vercel_ai.ts b/test/agent_sdks/vercel_ai.ts index 77fda22..bf1585e 100644 --- a/test/agent_sdks/vercel_ai.ts +++ b/test/agent_sdks/vercel_ai.ts @@ -95,7 +95,6 @@ async function runChatMode() { ); const tools = createVercelAITools(solanaAgent); - console.log(tools); const rl = readline.createInterface({ input: process.stdin,