From d277d0d2834155fab91bd909f0460d5628555499 Mon Sep 17 00:00:00 2001 From: michaelessiet Date: Wed, 1 Jan 2025 20:46:56 +0100 Subject: [PATCH] feat: get all token balances --- .../app/api/chat/route.ts | 80 +++++++++--------- .../data/DefaultRetrievalText.ts | 2 +- .../utils/markdownToHTML.ts | 30 +++---- src/actions/balance.ts | 1 - src/actions/createOrcaSingleSidedWhirlpool.ts | 2 +- src/actions/index.ts | 57 ++++++------- src/agent/index.ts | 30 +++++-- src/tools/get_balance.ts | 49 ++++++++++- src/utils/tokenMetadata.ts | 83 +++++++++++++++++++ test/index.ts | 2 +- 10 files changed, 239 insertions(+), 97 deletions(-) create mode 100644 src/utils/tokenMetadata.ts diff --git a/examples/agent-kit-nextjs-langchain/app/api/chat/route.ts b/examples/agent-kit-nextjs-langchain/app/api/chat/route.ts index 06e6911..badd6b3 100644 --- a/examples/agent-kit-nextjs-langchain/app/api/chat/route.ts +++ b/examples/agent-kit-nextjs-langchain/app/api/chat/route.ts @@ -5,24 +5,24 @@ import { createReactAgent } from "@langchain/langgraph/prebuilt"; import { SolanaAgentKit, createSolanaTools } from "solana-agent-kit"; const llm = new ChatOpenAI({ - temperature: 0.7, - model: "gpt-4o-mini", + temperature: 0.7, + model: "gpt-4o-mini", }); const solanaAgent = new SolanaAgentKit( - process.env.SOLANA_PRIVATE_KEY!, - process.env.RPC_URL, - process.env.OPENAI_API_KEY!, + process.env.SOLANA_PRIVATE_KEY!, + process.env.RPC_URL, + process.env.OPENAI_API_KEY!, ); const tools = createSolanaTools(solanaAgent); const memory = new MemorySaver(); const agent = createReactAgent({ - llm, - tools, - checkpointSaver: memory, - messageModifier: ` + llm, + tools, + checkpointSaver: memory, + messageModifier: ` You are a helpful agent that can interact onchain using the Solana Agent Kit. You are empowered to interact onchain using your tools. If you ever need funds, you can request them from the faucet. If not, you can provide your wallet details and request funds from the user. If there is a 5XX @@ -34,38 +34,38 @@ const agent = createReactAgent({ }); export async function POST(req: NextRequest) { - try { - const body = await req.json(); - const messages = body.messages ?? []; + try { + const body = await req.json(); + const messages = body.messages ?? []; - const eventStream = agent.streamEvents( - { - messages, - }, - { - version: "v2", - configurable: { - thread_id: "Solana Agent Kit!", - }, - }, - ); + const eventStream = agent.streamEvents( + { + messages, + }, + { + version: "v2", + configurable: { + thread_id: "Solana Agent Kit!", + }, + }, + ); - const textEncoder = new TextEncoder(); - const transformStream = new ReadableStream({ - async start(controller) { - for await (const { event, data } of eventStream) { - if (event === "on_chat_model_stream") { - if (!!data.chunk.content) { - controller.enqueue(textEncoder.encode(data.chunk.content)); - } - } - } - controller.close(); - }, - }); + const textEncoder = new TextEncoder(); + const transformStream = new ReadableStream({ + async start(controller) { + for await (const { event, data } of eventStream) { + if (event === "on_chat_model_stream") { + if (data.chunk.content) { + controller.enqueue(textEncoder.encode(data.chunk.content)); + } + } + } + controller.close(); + }, + }); - return new Response(transformStream); - } catch (e: any) { - return NextResponse.json({ error: e.message }, { status: e.status ?? 500 }); - } + return new Response(transformStream); + } catch (e: any) { + return NextResponse.json({ error: e.message }, { status: e.status ?? 500 }); + } } diff --git a/examples/agent-kit-nextjs-langchain/data/DefaultRetrievalText.ts b/examples/agent-kit-nextjs-langchain/data/DefaultRetrievalText.ts index 898acba..6973d98 100644 --- a/examples/agent-kit-nextjs-langchain/data/DefaultRetrievalText.ts +++ b/examples/agent-kit-nextjs-langchain/data/DefaultRetrievalText.ts @@ -537,4 +537,4 @@ const executor = await initializeAgentExecutorWithOptions(tools, llm, { }, }); \`\`\` -`; \ No newline at end of file +`; diff --git a/examples/agent-kit-nextjs-langchain/utils/markdownToHTML.ts b/examples/agent-kit-nextjs-langchain/utils/markdownToHTML.ts index dc265b1..135fdd9 100644 --- a/examples/agent-kit-nextjs-langchain/utils/markdownToHTML.ts +++ b/examples/agent-kit-nextjs-langchain/utils/markdownToHTML.ts @@ -2,29 +2,29 @@ import { marked } from "marked"; import DOMPurify from "isomorphic-dompurify"; interface MarkedOptions { - gfm: boolean; - breaks: boolean; - headerIds: boolean; - mangle: false; - highlight?: (code: string, lang: string) => string; + gfm: boolean; + breaks: boolean; + headerIds: boolean; + mangle: false; + highlight?: (code: string, lang: string) => string; } // Configure marked options const markedOptions: MarkedOptions = { - gfm: true, // GitHub Flavored Markdown - breaks: true, // Convert \n to
- headerIds: true, // Add ids to headers - mangle: false, // Don't escape HTML - highlight: function (code: string, lang: string): string { - // You can add syntax highlighting here if needed - return code; - }, + gfm: true, // GitHub Flavored Markdown + breaks: true, // Convert \n to
+ headerIds: true, // Add ids to headers + mangle: false, // Don't escape HTML + highlight: function (code: string, lang: string): string { + // You can add syntax highlighting here if needed + return code; + }, }; marked.setOptions(markedOptions); // Basic markdown to HTML conversion with sanitization export default function markdownToHtml(markdown: string) { - const rawHtml = marked.parse(markdown); - return DOMPurify.sanitize(rawHtml as string); + const rawHtml = marked.parse(markdown); + return DOMPurify.sanitize(rawHtml as string); } diff --git a/src/actions/balance.ts b/src/actions/balance.ts index 381b2a5..1ee1b56 100644 --- a/src/actions/balance.ts +++ b/src/actions/balance.ts @@ -54,7 +54,6 @@ const balanceAction: Action = { return { status: "success", balance: balance, - token: input.tokenAddress || "SOL", }; }, }; diff --git a/src/actions/createOrcaSingleSidedWhirlpool.ts b/src/actions/createOrcaSingleSidedWhirlpool.ts index affc445..3a1fb52 100644 --- a/src/actions/createOrcaSingleSidedWhirlpool.ts +++ b/src/actions/createOrcaSingleSidedWhirlpool.ts @@ -87,7 +87,7 @@ const createOrcaSingleSidedWhirlpoolAction: Action = { const otherTokenMint = new PublicKey(input.otherTokenMint); const initialPrice = new Decimal(input.initialPrice); const maxPrice = new Decimal(input.maxPrice); - const feeTier = input.feeTier + const feeTier = input.feeTier; // Create the whirlpool const signature = await orcaCreateSingleSidedLiquidityPool( diff --git a/src/actions/index.ts b/src/actions/index.ts index b66c89e..1ccdb64 100644 --- a/src/actions/index.ts +++ b/src/actions/index.ts @@ -28,34 +28,35 @@ import createOrcaSingleSidedWhirlpoolAction from "./createOrcaSingleSidedWhirlpo import launchPumpfunTokenAction from "./launchPumpfunToken"; export const ACTIONS = { - "DEPLOY_TOKEN_ACTION" : deployTokenAction, - "BALANCE_ACTION" : balanceAction, - "TRANSFER_ACTION" : transferAction, - "DEPLOY_COLLECTION_ACTION" : deployCollectionAction, - "MINT_NFT_ACTION" : mintNFTAction, - "TRADE_ACTION" : tradeAction, - "REQUEST_FUNDS_ACTION" : requestFundsAction, - "RESOLVE_DOMAIN_ACTION" : resolveDomainAction, - "GET_TOKEN_DATA_ACTION" : getTokenDataAction, - "GET_TPS_ACTION" : getTPSAction, - "FETCH_PRICE_ACTION" : fetchPriceAction, - "STAKE_WITH_JUP_ACTION" : stakeWithJupAction, - "REGISTER_DOMAIN_ACTION" : registerDomainAction, - "LEND_ASSET_ACTION" : lendAssetAction, - "CREATE_GIBWORK_TASK_ACTION" : createGibworkTaskAction, - "RESOLVE_SOL_DOMAIN_ACTION" : resolveSolDomainAction, - "PYTH_FETCH_PRICE_ACTION" : pythFetchPriceAction, - "GET_OWNED_DOMAINS_FOR_TLD_ACTION" : getOwnedDomainsForTLDAction, - "GET_PRIMARY_DOMAIN_ACTION" : getPrimaryDomainAction, - "GET_ALL_DOMAINS_TLDS_ACTION" : getAllDomainsTLDsAction, - "GET_OWNED_ALL_DOMAINS_ACTION" : getOwnedAllDomainsAction, - "CREATE_IMAGE_ACTION" : createImageAction, - "GET_MAIN_ALL_DOMAINS_DOMAIN_ACTION" : getMainAllDomainsDomainAction, - "GET_ALL_REGISTERED_ALL_DOMAINS_ACTION" : getAllRegisteredAllDomainsAction, - "RAYDIUM_CREATE_CPMM_ACTION" : raydiumCreateCpmmAction, - "RAYDIUM_CREATE_AMM_V4_ACTION" : raydiumCreateAmmV4Action, - "CREATE_ORCA_SINGLE_SIDED_WHIRLPOOL_ACTION" : createOrcaSingleSidedWhirlpoolAction, - "LAUNCH_PUMPFUN_TOKEN_ACTION" : launchPumpfunTokenAction, + DEPLOY_TOKEN_ACTION: deployTokenAction, + BALANCE_ACTION: balanceAction, + TRANSFER_ACTION: transferAction, + DEPLOY_COLLECTION_ACTION: deployCollectionAction, + MINT_NFT_ACTION: mintNFTAction, + TRADE_ACTION: tradeAction, + REQUEST_FUNDS_ACTION: requestFundsAction, + RESOLVE_DOMAIN_ACTION: resolveDomainAction, + GET_TOKEN_DATA_ACTION: getTokenDataAction, + GET_TPS_ACTION: getTPSAction, + FETCH_PRICE_ACTION: fetchPriceAction, + STAKE_WITH_JUP_ACTION: stakeWithJupAction, + REGISTER_DOMAIN_ACTION: registerDomainAction, + LEND_ASSET_ACTION: lendAssetAction, + CREATE_GIBWORK_TASK_ACTION: createGibworkTaskAction, + RESOLVE_SOL_DOMAIN_ACTION: resolveSolDomainAction, + PYTH_FETCH_PRICE_ACTION: pythFetchPriceAction, + GET_OWNED_DOMAINS_FOR_TLD_ACTION: getOwnedDomainsForTLDAction, + GET_PRIMARY_DOMAIN_ACTION: getPrimaryDomainAction, + GET_ALL_DOMAINS_TLDS_ACTION: getAllDomainsTLDsAction, + GET_OWNED_ALL_DOMAINS_ACTION: getOwnedAllDomainsAction, + CREATE_IMAGE_ACTION: createImageAction, + GET_MAIN_ALL_DOMAINS_DOMAIN_ACTION: getMainAllDomainsDomainAction, + GET_ALL_REGISTERED_ALL_DOMAINS_ACTION: getAllRegisteredAllDomainsAction, + RAYDIUM_CREATE_CPMM_ACTION: raydiumCreateCpmmAction, + RAYDIUM_CREATE_AMM_V4_ACTION: raydiumCreateAmmV4Action, + CREATE_ORCA_SINGLE_SIDED_WHIRLPOOL_ACTION: + createOrcaSingleSidedWhirlpoolAction, + LAUNCH_PUMPFUN_TOKEN_ACTION: launchPumpfunTokenAction, }; export type { Action, ActionExample, Handler } from "../types/action"; diff --git a/src/agent/index.ts b/src/agent/index.ts index 4acf694..8191a80 100644 --- a/src/agent/index.ts +++ b/src/agent/index.ts @@ -83,24 +83,30 @@ export class SolanaAgentKit { * @deprecated Using openai_api_key directly in constructor is deprecated. * Please use the new constructor with Config object instead: * @example - * const agent = new SolanaAgentKit(privateKey, rpcUrl, { + * const agent = new SolanaAgentKit(privateKey, rpcUrl, { * OPENAI_API_KEY: 'your-key' * }); */ - constructor(private_key: string, rpc_url: string, openai_api_key: string | null); + constructor( + private_key: string, + rpc_url: string, + openai_api_key: string | null, + ); constructor(private_key: string, rpc_url: string, config: Config); constructor( private_key: string, rpc_url: string, configOrKey: Config | string | null, ) { - this.connection = new Connection(rpc_url || "https://api.mainnet-beta.solana.com"); + this.connection = new Connection( + rpc_url || "https://api.mainnet-beta.solana.com", + ); this.wallet = Keypair.fromSecretKey(bs58.decode(private_key)); this.wallet_address = this.wallet.publicKey; // Handle both old and new patterns - if (typeof configOrKey === 'string' || configOrKey === null) { - this.config = { OPENAI_API_KEY: configOrKey || '' }; + if (typeof configOrKey === "string" || configOrKey === null) { + this.config = { OPENAI_API_KEY: configOrKey || "" }; } else { this.config = configOrKey; } @@ -127,7 +133,19 @@ export class SolanaAgentKit { return deploy_collection(this, options); } - async getBalance(token_address?: PublicKey): Promise { + async getBalance(token_address?: PublicKey): Promise< + | number + | { + sol: number; + tokens: Array<{ + tokenAddress: string; + name: string; + symbol: string; + balance: number; + decimals: number; + }>; + } + > { return get_balance(this, token_address); } diff --git a/src/tools/get_balance.ts b/src/tools/get_balance.ts index a1e3736..3e30021 100644 --- a/src/tools/get_balance.ts +++ b/src/tools/get_balance.ts @@ -1,5 +1,7 @@ import { LAMPORTS_PER_SOL, PublicKey } from "@solana/web3.js"; import { SolanaAgentKit } from "../index"; +import { TOKEN_PROGRAM_ID } from "@solana/spl-token"; +import { getTokenMetadata } from "../utils/tokenMetadata"; /** * Get the balance of SOL or an SPL token for the agent's wallet @@ -10,12 +12,51 @@ import { SolanaAgentKit } from "../index"; export async function get_balance( agent: SolanaAgentKit, token_address?: PublicKey, -): Promise { +): Promise< + | number + | { + sol: number; + tokens: Array<{ + tokenAddress: string; + name: string; + symbol: string; + balance: number; + decimals: number; + }>; + } +> { if (!token_address) { - return ( - (await agent.connection.getBalance(agent.wallet_address)) / - LAMPORTS_PER_SOL + const [lamportsBalance, tokenAccountData] = await Promise.all([ + agent.connection.getBalance(agent.wallet_address), + agent.connection.getParsedTokenAccountsByOwner(agent.wallet_address, { + programId: TOKEN_PROGRAM_ID, + }), + ]); + + const removedZeroBalance = tokenAccountData.value.filter( + (v) => v.account.data.parsed.info.tokenAmount.uiAmount !== 0, ); + + const tokenBalances = await Promise.all( + removedZeroBalance.map(async (v) => { + const mint = v.account.data.parsed.info.mint; + const mintInfo = await getTokenMetadata(agent.connection, mint); + return { + tokenAddress: mint, + name: mintInfo.name ?? "", + symbol: mintInfo.symbol ?? "", + balance: v.account.data.parsed.info.tokenAmount.uiAmount as number, + decimals: v.account.data.parsed.info.tokenAmount.decimals as number, + }; + }), + ); + + const solBalance = lamportsBalance / LAMPORTS_PER_SOL; + + return { + sol: solBalance, + tokens: tokenBalances, + }; } const token_account = diff --git a/src/utils/tokenMetadata.ts b/src/utils/tokenMetadata.ts new file mode 100644 index 0000000..514a208 --- /dev/null +++ b/src/utils/tokenMetadata.ts @@ -0,0 +1,83 @@ +import { Connection, PublicKey } from "@solana/web3.js"; + +export async function getTokenMetadata( + connection: Connection, + tokenMint: string, +) { + const METADATA_PROGRAM_ID = new PublicKey( + "metaqbxxUerdq28cj1RbAWkYQm3ybzjb6a8bt518x1s", + ); + + const [metadataPDA] = PublicKey.findProgramAddressSync( + [ + Buffer.from("metadata"), + METADATA_PROGRAM_ID.toBuffer(), + new PublicKey(tokenMint).toBuffer(), + ], + METADATA_PROGRAM_ID, + ); + + const metadata = await connection.getAccountInfo(metadataPDA); + if (!metadata?.data) { + throw new Error("Metadata not found"); + } + + let offset = 1 + 32 + 32; // key + update auth + mint + const data = metadata.data; + const decoder = new TextDecoder(); + + // Read variable length strings + const readString = () => { + let nameLength = data[offset]; + + while (nameLength === 0) { + offset++; + nameLength = data[offset]; + if (offset >= data.length) { + return null; + } + } + + offset++; + const name = decoder + .decode(data.slice(offset, offset + nameLength)) + // @eslint-disable-next-line no-control-regex + .replace(new RegExp(String.fromCharCode(0), "g"), ""); + offset += nameLength; + return name; + }; + + const name = readString(); + const symbol = readString(); + const uri = readString(); + + // Read remaining data + const sellerFeeBasisPoints = data.readUInt16LE(offset); + offset += 2; + + let creators: + | { address: PublicKey; verified: boolean; share: number }[] + | null = null; + if (data[offset] === 1) { + offset++; + const numCreators = data[offset]; + offset++; + creators = [...Array(numCreators)].map(() => { + const creator = { + address: new PublicKey(data.slice(offset, offset + 32)), + verified: data[offset + 32] === 1, + share: data[offset + 33], + }; + offset += 34; + return creator; + }); + } + + return { + name, + symbol, + uri, + sellerFeeBasisPoints, + creators, + }; +} diff --git a/test/index.ts b/test/index.ts index 2a3312b..00f9976 100644 --- a/test/index.ts +++ b/test/index.ts @@ -1,4 +1,4 @@ -import { SolanaAgentKit , ACTIONS} from "../src"; +import { SolanaAgentKit, ACTIONS } from "../src"; import { createSolanaTools } from "../src/langchain"; import { HumanMessage } from "@langchain/core/messages"; import { MemorySaver } from "@langchain/langgraph";