diff --git a/src/actions/drift/updateDriftVaultDelegate.ts b/src/actions/drift/updateDriftVaultDelegate.ts new file mode 100644 index 0000000..defc7a6 --- /dev/null +++ b/src/actions/drift/updateDriftVaultDelegate.ts @@ -0,0 +1,53 @@ +import { z } from "zod"; +import type { Action } from "../../types"; +import { updateVaultDelegate } from "../../tools"; + +const updateDriftVaultDelegateAction: Action = { + name: "UPDATE_DRIFT_VAULT_DELEGATE_ACTION", + similes: ["update drift vault delegate", "change drift vault delegate"], + description: "Update the delegate of a drift vault", + examples: [ + [ + { + input: { + vaultAddress: "2nFeP7taii3wGVgrWk4YiLMPmhtu3Zg9iXCUu4zGBD", + newDelegate: "2nFeP7tai", + }, + output: { + status: "success", + message: "Vault delegate updated successfully", + signature: + "2nFeP7taii3wGVgrWk4YiLMPmhtu3Zg9iXCUu4zGBDadwunHw8reXFxRWT7khbFsQ9JT3zK4RYDLNDFDRYvM3wJk", + }, + explanation: "Update the delegate of a drift vault to another address", + }, + ], + ], + schema: z.object({ + vaultAddress: z.string(), + newDelegate: z.string(), + }), + handler: async (agent, input) => { + try { + const tx = await updateVaultDelegate( + agent, + input.vaultAddress as string, + input.newDelegate as string, + ); + + return { + status: "success", + message: "Vault delegate updated successfully", + signature: tx, + }; + } catch (e) { + return { + status: "error", + // @ts-expect-error - error message + message: `Failed to update vault delegate: ${e.message}`, + }; + } + }, +}; + +export default updateDriftVaultDelegateAction; diff --git a/src/actions/drift/updateVault.ts b/src/actions/drift/updateVault.ts index 530e1ce..4d0f66e 100644 --- a/src/actions/drift/updateVault.ts +++ b/src/actions/drift/updateVault.ts @@ -32,14 +32,25 @@ const updateDriftVaultAction: Action = { ], schema: z.object({ vaultAddress: z.string(), - name: z.string().min(5, "Name must be at least 5 characters"), + name: z.string().min(5, "Name must be at least 5 characters").optional(), // regex matches SOL-SPOT - marketName: z.string().regex(/^([A-Za-z0-9]{2,7})-SPOT$/), - redeemPeriod: z.number().int().min(1, "Redeem period must be at least 1"), - maxTokens: z.number().int().min(100, "Max tokens must be at least 100"), - minDepositAmount: z.number().positive(), - managementFee: z.number().positive().max(20), - profitShare: z.number().positive().max(90).optional().default(5), + marketName: z + .string() + .regex(/^([A-Za-z0-9]{2,7})-SPOT$/) + .optional(), + redeemPeriod: z + .number() + .int() + .min(1, "Redeem period must be at least 1") + .optional(), + maxTokens: z + .number() + .int() + .min(100, "Max tokens must be at least 100") + .optional(), + minDepositAmount: z.number().positive().optional(), + managementFee: z.number().positive().max(20).optional(), + profitShare: z.number().positive().max(90).optional(), handleRate: z.number().optional(), permissioned: z .boolean() diff --git a/src/actions/index.ts b/src/actions/index.ts index 4664ded..63ec7c5 100644 --- a/src/actions/index.ts +++ b/src/actions/index.ts @@ -57,6 +57,7 @@ import depositToDriftUserAccountAction from "./drift/depositToDriftUserAccount"; import withdrawFromDriftAccountAction from "./drift/withdrawFromDriftAccount"; import driftUserAccountInfoAction from "./drift/driftUserAccountInfo"; import deriveDriftVaultAddressAction from "./drift/deriveVaultAddress"; +import updateDriftVaultDelegateAction from "./drift/updateDriftVaultDelegate"; export const ACTIONS = { WALLET_ADDRESS_ACTION: getWalletAddressAction, @@ -119,6 +120,7 @@ export const ACTIONS = { WITHDRAW_OR_BORROW_FROM_DRIFT_ACCOUNT_ACTION: withdrawFromDriftAccountAction, DRIFT_USER_ACCOUNT_INFO_ACTION: driftUserAccountInfoAction, DERIVE_DRIFT_VAULT_ADDRESS_ACTION: deriveDriftVaultAddressAction, + UPDATE_DRIFT_VAULT_DELEGATE_ACTION: updateDriftVaultDelegateAction, }; export type { Action, ActionExample, Handler } from "../types/action"; diff --git a/src/agent/index.ts b/src/agent/index.ts index d889897..45f085f 100644 --- a/src/agent/index.ts +++ b/src/agent/index.ts @@ -96,6 +96,7 @@ import { getVaultInfo, withdrawFromDriftUserAccount, withdrawFromDriftVault, + updateVaultDelegate, } from "../tools"; import { Config, @@ -803,4 +804,7 @@ export class SolanaAgentKit { async withdrawFromDriftVault(vault: string) { return await withdrawFromDriftVault(this, vault); } + async updateDriftVaultDelegate(vaultAddress: string, delegate: string) { + return await updateVaultDelegate(this, vaultAddress, delegate); + } } diff --git a/src/tools/drift/drift.ts b/src/tools/drift/drift.ts index d112fca..97105a4 100644 --- a/src/tools/drift/drift.ts +++ b/src/tools/drift/drift.ts @@ -24,7 +24,14 @@ import { PublicKey } from "@solana/web3.js"; import { Transaction } from "@solana/web3.js"; import { ComputeBudgetProgram } from "@solana/web3.js"; -export async function initClients(agent: SolanaAgentKit) { +export async function initClients( + agent: SolanaAgentKit, + params?: { + authority: PublicKey; + activeSubAccountId: number; + subAccountIds: number[]; + }, +) { const wallet: IWallet = { publicKey: agent.wallet.publicKey, payer: agent.wallet, @@ -40,10 +47,17 @@ export async function initClients(agent: SolanaAgentKit) { }, }; + // @ts-expect-error - false undefined type conflict const driftClient = new DriftClient({ connection: agent.connection, wallet, env: "mainnet-beta", + authority: params?.authority, + activeSubAccountId: params?.activeSubAccountId, + subAccountIds: params?.subAccountIds, + txParams: { + computeUnitsPrice: 0.000001 * 1000000 * 1000000, + }, txSender: new FastSingleTxSender({ connection: agent.connection, wallet, diff --git a/src/tools/drift/drift_vault.ts b/src/tools/drift/drift_vault.ts index 3bfa182..a65c485 100644 --- a/src/tools/drift/drift_vault.ts +++ b/src/tools/drift/drift_vault.ts @@ -14,8 +14,10 @@ import { PRICE_PRECISION, QUOTE_PRECISION, TEN, + User, } from "@drift-labs/sdk"; import { + VaultAccount, WithdrawUnit, encodeName, getVaultAddressSync, @@ -34,7 +36,7 @@ export function getMarketIndexAndType(name: `${string}-${string}`) { const [symbol, type] = name.toUpperCase().split("-"); if (type === "PERP") { - const token = MainnetPerpMarkets.find((v) => v.symbol === symbol); + const token = MainnetPerpMarkets.find((v) => v.baseAssetSymbol === symbol); if (!token) { throw new Error("Drift doesn't have that market"); } @@ -71,6 +73,31 @@ async function getOrCreateVaultDepositor(agent: SolanaAgentKit, vault: string) { } } +async function getVaultAvailableBalance(agent: SolanaAgentKit, vault: string) { + try { + const { cleanUp, vaultClient } = await initClients(agent); + const vaultDetails = await vaultClient.getVault(new PublicKey(vault)); + + const currentVaultBalance = convertToNumber( + vaultDetails.netDeposits, + QUOTE_PRECISION, + ); + const vaultWithdrawalsRequested = convertToNumber( + vaultDetails.totalWithdrawRequested, + QUOTE_PRECISION, + ); + const availableBalanceInUSD = + currentVaultBalance - vaultWithdrawalsRequested; + + await cleanUp(); + + return availableBalanceInUSD; + } catch (e) { + // @ts-expect-error - error message is a string + throw new Error(`Failed to get vault available balance: ${e.message}`); + } +} + /** Create a vault @param agent SolanaAgentKit instance @@ -151,6 +178,27 @@ export async function createVault( } } +export async function updateVaultDelegate( + agent: SolanaAgentKit, + vault: string, + delegateAddress: string, +) { + try { + const { vaultClient, cleanUp } = await initClients(agent); + const signature = await vaultClient.updateDelegate( + new PublicKey(vault), + new PublicKey(delegateAddress), + ); + await cleanUp(); + return signature; + } catch (e) { + throw new Error( + // @ts-expect-error - error message is a string + `Failed to update vault delegate: ${e.message}`, + ); + } +} + /** Update the vault's info @param agent SolanaAgentKit instance @@ -194,34 +242,32 @@ export async function updateVault( const spotPrecision = TEN.pow(new BN(spotMarket.decimals)); const tx = await vaultClient.managerUpdateVault(vaultPublicKey, { - redeemPeriod: new BN( - params.redeemPeriod - ? params.redeemPeriod * 86400 - : vaultDetails.redeemPeriod, - ), + redeemPeriod: params.redeemPeriod + ? new BN(params.redeemPeriod * 86400) + : null, maxTokens: params.maxTokens ? numberToSafeBN(params.maxTokens, spotPrecision) - : vaultDetails.maxTokens, + : null, minDepositAmount: params.minDepositAmount ? numberToSafeBN(params.minDepositAmount, spotPrecision) - : vaultDetails.minDepositAmount, + : null, managementFee: params.managementFee ? new BN(params.managementFee) .mul(PERCENTAGE_PRECISION) .div(new BN(100)) - : vaultDetails.managementFee, + : null, profitShare: params.profitShare ? new BN(params.profitShare) .mul(PERCENTAGE_PRECISION) .div(new BN(100)) .toNumber() - : vaultDetails.profitShare, + : null, hurdleRate: params.hurdleRate ? new BN(params.hurdleRate) .mul(PERCENTAGE_PRECISION) .div(new BN(100)) .toNumber() - : vaultDetails.hurdleRate, + : null, permissioned: params.permissioned ?? vaultDetails.permissioned, }); @@ -234,6 +280,12 @@ export async function updateVault( } } +/** + * Get information on a particular vault given its name + * @param agent + * @param vaultName + * @returns + */ export async function getVaultInfo(agent: SolanaAgentKit, vaultName: string) { try { const { vaultClient, cleanUp } = await initClients(agent); @@ -241,15 +293,20 @@ export async function getVaultInfo(agent: SolanaAgentKit, vaultName: string) { vaultClient.program.programId, encodeName(vaultName), ); - const vaultDetails = await vaultClient.getVault(vaultPublicKey); + const [vaultDetails, vaultBalance] = await Promise.all([ + vaultClient.getVault(vaultPublicKey), + getVaultAvailableBalance(agent, vaultPublicKey.toBase58()), + ]); await cleanUp(); const spotToken = MainnetSpotMarkets[vaultDetails.spotMarketIndex]; const data = { name: vaultName, + delegate: vaultDetails.delegate.toBase58(), address: vaultPublicKey.toBase58(), marketName: `${spotToken.symbol}-SPOT`, + balance: `${vaultBalance} ${spotToken.symbol}`, redeemPeriod: vaultDetails.redeemPeriod.toNumber(), maxTokens: vaultDetails.maxTokens.div(spotToken.precision).toNumber(), minDepositAmount: vaultDetails.minDepositAmount @@ -340,7 +397,7 @@ export async function requestWithdrawalFromVault( return await vaultClient.managerRequestWithdraw( vaultPublicKey, new BN(amount.toFixed(0)), - WithdrawUnit.SHARES, + WithdrawUnit.TOKEN, ); } @@ -349,7 +406,7 @@ export async function requestWithdrawalFromVault( const tx = await vaultClient.requestWithdraw( vaultDepositor, new BN(amount.toFixed(0)), - WithdrawUnit.SHARES, + WithdrawUnit.TOKEN, ); await cleanUp(); @@ -406,7 +463,7 @@ async function getIsOwned(agent: SolanaAgentKit, vault: string) { const { vaultClient, cleanUp } = await initClients(agent); const vaultPublicKey = new PublicKey(vault); const vaultDetails = await vaultClient.getVault(vaultPublicKey); - const isOwned = vaultDetails.delegate.equals(agent.wallet.publicKey); + const isOwned = vaultDetails.manager.equals(agent.wallet.publicKey); await cleanUp(); @@ -461,10 +518,13 @@ export async function tradeDriftVault( price?: number, ) { try { - const { driftClient, vaultClient, cleanUp } = await initClients(agent); - const [isOwned, vaultDetails, driftLookupTableAccount] = await Promise.all([ + const { driftClient, cleanUp } = await initClients(agent, { + authority: new PublicKey(vault), + activeSubAccountId: 0, + subAccountIds: [0], + }); + const [isOwned, driftLookupTableAccount] = await Promise.all([ getIsOwned(agent, vault), - vaultClient.getVault(new PublicKey(vault)), driftClient.fetchMarketLookupTableAccount(), ]); @@ -474,34 +534,11 @@ export async function tradeDriftVault( ); } - driftClient.authority = new PublicKey(vault); - driftClient.activeSubAccountId = 0; - vaultClient.driftClient = driftClient; - const usdcSpotMarket = driftClient.getSpotMarketAccount(0); if (!usdcSpotMarket) { throw new Error("USDC-SPOT market not found"); } - const usdcPrecision = TEN.pow(new BN(usdcSpotMarket.decimals)); - const vaultWithdrawalsRequested = convertToNumber( - vaultDetails.totalWithdrawRequested, - usdcPrecision, - ); - // this is actually the authority provided - const user = driftClient.getUser(); - const currentVaultBalance = - convertToNumber(user.getNetSpotMarketValue(), QUOTE_PRECISION) + - convertToNumber(user.getUnrealizedPNL(true), QUOTE_PRECISION); - const availableBalanceInUSD = - currentVaultBalance - vaultWithdrawalsRequested; - - if (amount > availableBalanceInUSD) { - throw new Error( - "Insufficient balance: You don't have enough balance to make this trade", - ); - } - const perpMarketIndexAndType = getMarketIndexAndType( `${symbol.toUpperCase()}-PERP`, ); @@ -569,12 +606,14 @@ export async function tradeDriftVault( instructions.push(instruction); } + const latestBlockhash = await driftClient.connection.getLatestBlockhash(); const tx = await driftClient.txSender.sendVersionedTransaction( await driftClient.txSender.getVersionedTransaction( instructions, [driftLookupTableAccount], [], driftClient.opts, + latestBlockhash, ), );