feat: get all token balances (#113)

# Pull Request Description

## Related Issue
This doesn't fix any pre-existing issue, but is an attempt to claim the
bring your own idea bounty.

## Changes Made
This PR adds the following changes:
<!-- List the key changes made in this PR -->
- when token address is not supplied get_balance returns an object
containing the amount of sol an address has but also the balance of each
of it's non-TOKEN2022 token accounts
- 
  
## Implementation Details
<!-- Provide technical details about the implementation -->
- While reading the Solana RPC docs I came across an endpoint that
fetches all the token accounts owned by an address returns their
balances along with that.
- So with this data I thought why not just implement get_balance to be
more comprehensive than just simple `number` values

## Transaction executed by agent 
<!-- If applicable, provide example usage, transactions, or screenshots
-->
Example transaction: 
<img width="610" alt="Screenshot 2025-01-11 at 18 10 20"
src="https://github.com/user-attachments/assets/f242af3c-8703-42aa-8a65-a6dd9b369392"
/>



## Prompt Used
<!-- If relevant, include the prompt or configuration used -->
```
What's my balance
```

## Additional Notes
<!-- Any additional information that reviewers should know -->

## Checklist
- [x] I have tested these changes locally
- [ ] I have updated the documentation
- [ ] I have added a transaction link
- [x] I have added the prompt used to test it
This commit is contained in:
aryan
2025-01-15 20:40:53 +05:30
committed by GitHub
8 changed files with 241 additions and 3 deletions

View File

@@ -1,3 +1,4 @@
import tokenBalancesAction from "./tokenBalances";
import deployTokenAction from "./metaplex/deployToken";
import balanceAction from "./solana/balance";
import transferAction from "./solana/transfer";
@@ -61,6 +62,7 @@ import updateDriftVaultDelegateAction from "./drift/updateDriftVaultDelegate";
export const ACTIONS = {
WALLET_ADDRESS_ACTION: getWalletAddressAction,
TOKEN_BALANCES_ACTION: tokenBalancesAction,
DEPLOY_TOKEN_ACTION: deployTokenAction,
BALANCE_ACTION: balanceAction,
TRANSFER_ACTION: transferAction,

View File

@@ -1,6 +1,6 @@
import { PublicKey } from "@solana/web3.js";
import { Action } from "../../types/action";
import { SolanaAgentKit } from "../../agent";
import type { Action } from "../../types/action";
import type { SolanaAgentKit } from "../../agent";
import { z } from "zod";
import { get_balance } from "../../tools";

View File

@@ -0,0 +1,80 @@
import { PublicKey } from "@solana/web3.js";
import type { Action } from "../types/action";
import type { SolanaAgentKit } from "../agent";
import { z } from "zod";
import { get_token_balance } from "../tools";
const tokenBalancesAction: Action = {
name: "TOKEN_BALANCE_ACTION",
similes: [
"check token balances",
"get wallet token balances",
"view token balances",
"show token balances",
"check token balance",
],
description: `Get the token balances of a Solana wallet.
If you want to get the balance of your wallet, you don't need to provide the wallet address.`,
examples: [
[
{
input: {},
output: {
status: "success",
balance: {
sol: 100,
tokens: [
{
tokenAddress: "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v",
name: "USD Coin",
symbol: "USDC",
balance: 100,
decimals: 9,
},
],
},
},
explanation: "Get token balances of the wallet",
},
],
[
{
input: {
walletAddress: "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v",
},
output: {
status: "success",
balance: {
sol: 100,
tokens: [
{
tokenAddress: "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v",
name: "USD Coin",
symbol: "USDC",
balance: 100,
decimals: 9,
},
],
},
},
explanation: "Get address token balance",
},
],
],
schema: z.object({
walletAddress: z.string().optional(),
}),
handler: async (agent: SolanaAgentKit, input) => {
const balance = await get_token_balance(
agent,
input.tokenAddress && new PublicKey(input.tokenAddress),
);
return {
status: "success",
balance: balance,
};
},
};
export default tokenBalancesAction;

View File

@@ -97,6 +97,7 @@ import {
withdrawFromDriftUserAccount,
withdrawFromDriftVault,
updateVaultDelegate,
get_token_balance,
} from "../tools";
import {
Config,
@@ -189,6 +190,19 @@ export class SolanaAgentKit {
return get_balance(this, token_address);
}
async getTokenBalances(wallet_address?: PublicKey): Promise<{
sol: number;
tokens: Array<{
tokenAddress: string;
name: string;
symbol: string;
balance: number;
decimals: number;
}>;
}> {
return get_token_balance(this, wallet_address);
}
async getBalanceOther(
walletAddress: PublicKey,
tokenAddress?: PublicKey,

View File

@@ -0,0 +1,59 @@
import { LAMPORTS_PER_SOL, type PublicKey } from "@solana/web3.js";
import type { SolanaAgentKit } from "../../index";
import { TOKEN_PROGRAM_ID } from "@solana/spl-token";
import { getTokenMetadata } from "../../utils/tokenMetadata";
/**
* Get the token balances of a Solana wallet
* @param agent - SolanaAgentKit instance
* @param token_address - Optional SPL token mint address. If not provided, returns SOL balance
* @returns Promise resolving to the balance as an object containing sol balance and token balances with their respective mints, symbols, names and decimals
*/
export async function get_token_balance(
agent: SolanaAgentKit,
walletAddress?: PublicKey,
): Promise<{
sol: number;
tokens: Array<{
tokenAddress: string;
name: string;
symbol: string;
balance: number;
decimals: number;
}>;
}> {
const [lamportsBalance, tokenAccountData] = await Promise.all([
agent.connection.getBalance(walletAddress ?? agent.wallet_address),
agent.connection.getParsedTokenAccountsByOwner(
walletAddress ?? agent.wallet_address,
{
programId: TOKEN_PROGRAM_ID,
},
),
]);
const removedZeroBalance = tokenAccountData.value.filter(
(v) => v.account.data.parsed.info.tokenAmount.uiAmount !== 0,
);
const tokenBalances = await Promise.all(
removedZeroBalance.map(async (v) => {
const mint = v.account.data.parsed.info.mint;
const mintInfo = await getTokenMetadata(agent.connection, mint);
return {
tokenAddress: mint,
name: mintInfo.name ?? "",
symbol: mintInfo.symbol ?? "",
balance: v.account.data.parsed.info.tokenAmount.uiAmount as number,
decimals: v.account.data.parsed.info.tokenAmount.decimals as number,
};
}),
);
const solBalance = lamportsBalance / LAMPORTS_PER_SOL;
return {
sol: solBalance,
tokens: tokenBalances,
};
}

View File

@@ -4,3 +4,4 @@ export * from "./close_empty_token_accounts";
export * from "./transfer";
export * from "./get_balance";
export * from "./get_balance_other";
export * from "./get_token_balances";

View File

@@ -0,0 +1,83 @@
import { Connection, PublicKey } from "@solana/web3.js";
export async function getTokenMetadata(
connection: Connection,
tokenMint: string,
) {
const METADATA_PROGRAM_ID = new PublicKey(
"metaqbxxUerdq28cj1RbAWkYQm3ybzjb6a8bt518x1s",
);
const [metadataPDA] = PublicKey.findProgramAddressSync(
[
Buffer.from("metadata"),
METADATA_PROGRAM_ID.toBuffer(),
new PublicKey(tokenMint).toBuffer(),
],
METADATA_PROGRAM_ID,
);
const metadata = await connection.getAccountInfo(metadataPDA);
if (!metadata?.data) {
throw new Error("Metadata not found");
}
let offset = 1 + 32 + 32; // key + update auth + mint
const data = metadata.data;
const decoder = new TextDecoder();
// Read variable length strings
const readString = () => {
let nameLength = data[offset];
while (nameLength === 0) {
offset++;
nameLength = data[offset];
if (offset >= data.length) {
return null;
}
}
offset++;
const name = decoder
.decode(data.slice(offset, offset + nameLength))
// @eslint-disable-next-line no-control-regex
.replace(new RegExp(String.fromCharCode(0), "g"), "");
offset += nameLength;
return name;
};
const name = readString();
const symbol = readString();
const uri = readString();
// Read remaining data
const sellerFeeBasisPoints = data.readUInt16LE(offset);
offset += 2;
let creators:
| { address: PublicKey; verified: boolean; share: number }[]
| null = null;
if (data[offset] === 1) {
offset++;
const numCreators = data[offset];
offset++;
creators = [...Array(numCreators)].map(() => {
const creator = {
address: new PublicKey(data.slice(offset, offset + 32)),
verified: data[offset + 32] === 1,
share: data[offset + 33],
};
offset += 34;
return creator;
});
}
return {
name,
symbol,
uri,
sellerFeeBasisPoints,
creators,
};
}

View File

@@ -95,7 +95,6 @@ async function runChatMode() {
);
const tools = createVercelAITools(solanaAgent);
console.log(tools);
const rl = readline.createInterface({
input: process.stdin,