From c26b92e7cd40c1eeb46af36dc20640bc609907fa Mon Sep 17 00:00:00 2001 From: Damjan Date: Fri, 27 Dec 2024 16:19:52 +0100 Subject: [PATCH] Add support for querying other wallet balance --- src/agent/index.ts | 12 +++++-- src/langchain/index.ts | 60 +++++++++++++++++++++++++++++----- src/tools/get_balance_other.ts | 50 ++++++++++++++++++++++++++++ src/tools/index.ts | 1 + 4 files changed, 113 insertions(+), 10 deletions(-) create mode 100644 src/tools/get_balance_other.ts diff --git a/src/agent/index.ts b/src/agent/index.ts index 56f66f4..5f420a4 100644 --- a/src/agent/index.ts +++ b/src/agent/index.ts @@ -6,6 +6,7 @@ import { deploy_collection, deploy_token, get_balance, + get_balance_other, getTPS, resolveSolDomain, getPrimaryDomain, @@ -60,12 +61,12 @@ export class SolanaAgentKit { public connection: Connection; public wallet: Keypair; public wallet_address: PublicKey; - public openai_api_key: string; + public openai_api_key: string | null; constructor( private_key: string, rpc_url = "https://api.mainnet-beta.solana.com", - openai_api_key: string, + openai_api_key: string | null = null, ) { this.connection = new Connection(rpc_url); this.wallet = Keypair.fromSecretKey(bs58.decode(private_key)); @@ -98,6 +99,13 @@ export class SolanaAgentKit { return get_balance(this, token_address); } + async getBalanceOther( + walletAddress: PublicKey, + tokenAddress?: PublicKey, + ): Promise { + return get_balance_other(this, walletAddress, tokenAddress); + } + async mintNFT( collectionMint: PublicKey, metadata: Parameters[2], diff --git a/src/langchain/index.ts b/src/langchain/index.ts index f28b000..502c686 100644 --- a/src/langchain/index.ts +++ b/src/langchain/index.ts @@ -32,7 +32,7 @@ export class SolanaBalanceTool extends Tool { return JSON.stringify({ status: "success", - balance: balance, + balance, token: input || "SOL", }); } catch (error: any) { @@ -45,6 +45,49 @@ export class SolanaBalanceTool extends Tool { } } +export class SolanaBalanceOtherTool extends Tool { + name = "solana_balance_other"; + description = `Get the balance of a Solana wallet or token account different from the agent's wallet. + + If no tokenAddress is provided, the SOL balance of the wallet will be returned. + + Inputs: + walletAddress: string, eg "GDEkQF7UMr7RLv1KQKMtm8E2w3iafxJLtyXu3HVQZnME" (required) + tokenAddress: string, eg "SENDdRQtYMWaQrBroBrJ2Q53fgVuq95CV9UPGEvpCxa" (optional)`; + + constructor(private solanaKit: SolanaAgentKit) { + super(); + } + + protected async _call(input: string): Promise { + try { + const { walletAddress, tokenAddress } = JSON.parse(input); + + const tokenPubKey = tokenAddress + ? new PublicKey(tokenAddress) + : undefined; + + const balance = await this.solanaKit.getBalanceOther( + new PublicKey(walletAddress), + tokenPubKey, + ); + + return JSON.stringify({ + status: "success", + balance, + wallet: walletAddress, + token: tokenAddress || "SOL", + }); + } catch (error: any) { + return JSON.stringify({ + status: "error", + message: error.message, + code: error.code || "UNKNOWN_ERROR", + }); + } + } +} + export class SolanaTransferTool extends Tool { name = "solana_transfer"; description = `Transfer tokens or SOL to another address ( also called as wallet address ). @@ -555,7 +598,7 @@ export class SolanaLendAssetTool extends Tool { status: "success", message: "Asset lent successfully", transaction: tx, - amount: amount, + amount, }); } catch (error: any) { return JSON.stringify({ @@ -669,7 +712,7 @@ export class SolanaTokenDataTool extends Tool { return JSON.stringify({ status: "success", - tokenData: tokenData, + tokenData, }); } catch (error: any) { return JSON.stringify({ @@ -698,7 +741,7 @@ export class SolanaTokenDataByTickerTool extends Tool { const tokenData = await this.solanaKit.getTokenDataByTicker(ticker); return JSON.stringify({ status: "success", - tokenData: tokenData, + tokenData, }); } catch (error: any) { return JSON.stringify({ @@ -1005,7 +1048,7 @@ export class SolanaPythFetchPrice extends Tool { const response: PythFetchPriceResponse = { status: "success", priceFeedID: input, - price: price, + price, }; return JSON.stringify(response); } catch (error: any) { @@ -1079,7 +1122,7 @@ export class SolanaGetOwnedDomains extends Tool { return JSON.stringify({ status: "success", message: "Owned domains fetched successfully", - domains: domains, + domains, }); } catch (error: any) { return JSON.stringify({ @@ -1109,7 +1152,7 @@ export class SolanaGetOwnedTldDomains extends Tool { return JSON.stringify({ status: "success", message: "TLD domains fetched successfully", - domains: domains, + domains, }); } catch (error: any) { return JSON.stringify({ @@ -1136,7 +1179,7 @@ export class SolanaGetAllTlds extends Tool { return JSON.stringify({ status: "success", message: "TLDs fetched successfully", - tlds: tlds, + tlds, }); } catch (error: any) { return JSON.stringify({ @@ -1232,6 +1275,7 @@ export class SolanaCreateGibworkTask extends Tool { export function createSolanaTools(solanaKit: SolanaAgentKit) { return [ new SolanaBalanceTool(solanaKit), + new SolanaBalanceOtherTool(solanaKit), new SolanaTransferTool(solanaKit), new SolanaDeployTokenTool(solanaKit), new SolanaDeployCollectionTool(solanaKit), diff --git a/src/tools/get_balance_other.ts b/src/tools/get_balance_other.ts new file mode 100644 index 0000000..84c1c32 --- /dev/null +++ b/src/tools/get_balance_other.ts @@ -0,0 +1,50 @@ +import { + LAMPORTS_PER_SOL, + ParsedAccountData, + PublicKey, +} from "@solana/web3.js"; +import { SolanaAgentKit } from "../index"; + +/** + * Get the balance of SOL or an SPL token for the specified wallet address (other than the agent's wallet) + * @param agent - SolanaAgentKit instance + * @param wallet_address - Public key of the wallet to check balance for + * @param token_address - Optional SPL token mint address. If not provided, returns SOL balance + * @returns Promise resolving to the balance as a number (in UI units) or 0 if account doesn't exist + */ +export async function get_balance_other( + agent: SolanaAgentKit, + wallet_address: PublicKey, + token_address?: PublicKey, +): Promise { + try { + if (!token_address) { + return ( + (await agent.connection.getBalance(wallet_address)) / LAMPORTS_PER_SOL + ); + } + + const tokenAccounts = await agent.connection.getTokenAccountsByOwner( + wallet_address, + { mint: token_address }, + ); + + if (tokenAccounts.value.length === 0) { + console.warn( + `No token accounts found for wallet ${wallet_address.toString()} and token ${token_address.toString()}`, + ); + return 0; + } + + const tokenAccount = await agent.connection.getParsedAccountInfo( + tokenAccounts.value[0].pubkey, + ); + const tokenData = tokenAccount.value?.data as ParsedAccountData; + + return tokenData.parsed?.info?.tokenAmount?.uiAmount || 0; + } catch (error) { + throw new Error( + `Error fetching on-chain balance for ${token_address?.toString()}: ${error}`, + ); + } +} diff --git a/src/tools/index.ts b/src/tools/index.ts index 9a18cf4..ca61113 100644 --- a/src/tools/index.ts +++ b/src/tools/index.ts @@ -2,6 +2,7 @@ export * from "./request_faucet_funds"; export * from "./deploy_token"; export * from "./deploy_collection"; export * from "./get_balance"; +export * from "./get_balance_other"; export * from "./mint_nft"; export * from "./transfer"; export * from "./trade";