From f0d84f692481b320fca2b2d8c91e7fe7fbd6b969 Mon Sep 17 00:00:00 2001 From: michaelessiet Date: Tue, 14 Jan 2025 12:43:28 +0100 Subject: [PATCH] feat: add drift functions to agent class --- src/actions/drift/createVault.ts | 4 +- src/actions/drift/deriveVaultAddress.ts | 46 ++++++++ src/actions/drift/tradeDelegatedDriftVault.ts | 4 +- src/actions/drift/vaultInfo.ts | 27 +---- src/actions/index.ts | 2 + src/agent/index.ts | 109 ++++++++++++++++++ src/tools/drift/drift.ts | 2 +- src/tools/drift/drift_vault.ts | 65 +++++++++-- 8 files changed, 223 insertions(+), 36 deletions(-) create mode 100644 src/actions/drift/deriveVaultAddress.ts diff --git a/src/actions/drift/createVault.ts b/src/actions/drift/createVault.ts index 1d73fc9..26ac59b 100644 --- a/src/actions/drift/createVault.ts +++ b/src/actions/drift/createVault.ts @@ -90,7 +90,9 @@ const createDriftVaultAction: Action = { return { status: "success", - message: "Drift vault created successfully", + message: + "Drift vault created successfully. Please note down the name of your vault as it is unique and was used to derive your vault address", + vaultName: input.name, signature: tx, }; } catch (e) { diff --git a/src/actions/drift/deriveVaultAddress.ts b/src/actions/drift/deriveVaultAddress.ts new file mode 100644 index 0000000..63ce7ee --- /dev/null +++ b/src/actions/drift/deriveVaultAddress.ts @@ -0,0 +1,46 @@ +import { z } from "zod"; +import type { Action } from "../../types"; +import { getVaultAddress } from "../../tools"; + +const deriveDriftVaultAddressAction: Action = { + name: "DERIVE_DRIFT_VAULT_ADDRESS_ACTION", + similes: ["derive drift vault address", "get drift vault address"], + description: "Derive a drift vault address from the vaults name", + examples: [ + [ + { + input: { + name: "My Drift Vault", + }, + output: { + status: "success", + message: "Vault address derived successfully", + address: "2nFeP7taii3wGVgrWk4YiLMPmhtu3Zg9iXCUu4zGBD", + }, + explanation: "Derive a drift vault address", + }, + ], + ], + schema: z.object({ + name: z.string().describe("The name of the vault to derive the address of"), + }), + handler: async (agent, input) => { + try { + const address = await getVaultAddress(agent, input.name as string); + + return { + status: "success", + message: "Vault address derived successfully", + address, + }; + } catch (e) { + return { + status: "error", + // @ts-expect-error - error message + message: `Failed to derive vault address: ${e.message}`, + }; + } + }, +}; + +export default deriveDriftVaultAddressAction; diff --git a/src/actions/drift/tradeDelegatedDriftVault.ts b/src/actions/drift/tradeDelegatedDriftVault.ts index a51e143..c85d8b3 100644 --- a/src/actions/drift/tradeDelegatedDriftVault.ts +++ b/src/actions/drift/tradeDelegatedDriftVault.ts @@ -66,7 +66,7 @@ const tradeDelegatedDriftVaultAction: Action = { vaultAddress: z.string().describe("Address of the Drift vault to trade in"), amount: z.number().positive().describe("Amount to trade"), symbol: z.string().describe("Symbol of the token to trade"), - action: z.enum(["buy", "sell"]).describe("Trade action - buy or sell"), + action: z.enum(["long", "short"]).describe("Trade action - long or short"), type: z.enum(["market", "limit"]).describe("Trade type - market or limit"), price: z.number().positive().optional().describe("Price for limit order"), }), @@ -76,7 +76,7 @@ const tradeDelegatedDriftVaultAction: Action = { vaultAddress: input.vaultAddress as string, amount: input.amount as number, symbol: input.symbol as string, - action: input.action as "buy" | "sell", + action: input.action as "long" | "short", type: input.type as "market" | "limit", price: input.price as number | undefined, }; diff --git a/src/actions/drift/vaultInfo.ts b/src/actions/drift/vaultInfo.ts index c354ae5..279f78b 100644 --- a/src/actions/drift/vaultInfo.ts +++ b/src/actions/drift/vaultInfo.ts @@ -13,7 +13,7 @@ const vaultInfoAction: Action = { [ { input: { - vaultAddress: "2nFeP7taii", + vaultName: "test-vault", }, output: { status: "success", @@ -35,35 +35,16 @@ const vaultInfoAction: Action = { ], ], schema: z.object({ - vaultAddress: z.string(), + vaultName: z.string(), }), handler: async (agent: SolanaAgentKit, input) => { try { - const vaultInfo = await getVaultInfo(agent, input.vaultAddress as string); - const spotToken = MainnetSpotMarkets[vaultInfo.spotMarketIndex]; - const data = { - name: decodeName(vaultInfo.name), - marketName: `${spotToken.symbol}-SPOT`, - redeemPeriod: vaultInfo.redeemPeriod.toNumber(), - maxTokens: vaultInfo.maxTokens.div(spotToken.precision).toNumber(), - minDepositAmount: vaultInfo.minDepositAmount - .div(spotToken.precision) - .toNumber(), - managementFee: - (vaultInfo.managementFee.toNumber() / - PERCENTAGE_PRECISION.toNumber()) * - 100, - profitShare: - (vaultInfo.profitShare / PERCENTAGE_PRECISION.toNumber()) * 100, - hurdleRate: - (vaultInfo.hurdleRate / PERCENTAGE_PRECISION.toNumber()) * 100, - permissioned: vaultInfo.permissioned, - }; + const vaultInfo = await getVaultInfo(agent, input.vaultName as string); return { status: "success", message: "Vault info retrieved successfully", - data, + data: vaultInfo, }; } catch (e) { return { diff --git a/src/actions/index.ts b/src/actions/index.ts index 3cbc7df..4664ded 100644 --- a/src/actions/index.ts +++ b/src/actions/index.ts @@ -56,6 +56,7 @@ import doesUserHaveDriftAccountAction from "./drift/doesUserHaveDriftAccount"; import depositToDriftUserAccountAction from "./drift/depositToDriftUserAccount"; import withdrawFromDriftAccountAction from "./drift/withdrawFromDriftAccount"; import driftUserAccountInfoAction from "./drift/driftUserAccountInfo"; +import deriveDriftVaultAddressAction from "./drift/deriveVaultAddress"; export const ACTIONS = { WALLET_ADDRESS_ACTION: getWalletAddressAction, @@ -117,6 +118,7 @@ export const ACTIONS = { DEPOSIT_TO_DRIFT_USER_ACCOUNT_ACTION: depositToDriftUserAccountAction, WITHDRAW_OR_BORROW_FROM_DRIFT_ACCOUNT_ACTION: withdrawFromDriftAccountAction, DRIFT_USER_ACCOUNT_INFO_ACTION: driftUserAccountInfoAction, + DERIVE_DRIFT_VAULT_ADDRESS_ACTION: deriveDriftVaultAddressAction, }; export type { Action, ActionExample, Handler } from "../types/action"; diff --git a/src/agent/index.ts b/src/agent/index.ts index ff6352c..d889897 100644 --- a/src/agent/index.ts +++ b/src/agent/index.ts @@ -82,6 +82,20 @@ import { getHeliusWebhook, create_HeliusWebhook, deleteHeliusWebhook, + createDriftUserAccount, + createVault, + depositIntoVault, + depositToDriftUserAccount, + getVaultAddress, + doesUserHaveDriftAccount, + driftUserAccountInfo, + requestWithdrawalFromVault, + tradeDriftVault, + driftPerpTrade, + updateVault, + getVaultInfo, + withdrawFromDriftUserAccount, + withdrawFromDriftVault, } from "../tools"; import { Config, @@ -694,4 +708,99 @@ export class SolanaAgentKit { async deleteWebhook(webhookID: string): Promise { return deleteHeliusWebhook(this, webhookID); } + + async createDriftUserAccount(depositAmount: number, depositSymbol: string) { + return await createDriftUserAccount(this, depositAmount, depositSymbol); + } + async createDriftVault(params: { + name: string; + marketName: `${string}-${string}`; + redeemPeriod: number; + maxTokens: number; + minDepositAmount: number; + managementFee: number; + profitShare: number; + hurdleRate?: number; + permissioned?: boolean; + }) { + return await createVault(this, params); + } + async depositIntoDriftVault(amount: number, vault: string) { + return await depositIntoVault(this, amount, vault); + } + async depositToDriftUserAccount( + amount: number, + symbol: string, + isRepayment?: boolean, + ) { + return await depositToDriftUserAccount(this, amount, symbol, isRepayment); + } + async deriveDriftVaultAddress(name: string) { + return await getVaultAddress(this, name); + } + async doesUserHaveDriftAccount() { + return await doesUserHaveDriftAccount(this); + } + async driftUserAccountInfo() { + return await driftUserAccountInfo(this); + } + async requestWithdrawalFromDriftVault(amount: number, vault: string) { + return await requestWithdrawalFromVault(this, amount, vault); + } + async tradeUsingDelegatedDriftVault( + vault: string, + amount: number, + symbol: string, + action: "long" | "short", + type: "market" | "limit", + price?: number, + ) { + return await tradeDriftVault( + this, + vault, + amount, + symbol, + action, + type, + price, + ); + } + async tradeUsingDriftPerpAccount( + amount: number, + symbol: string, + action: "long" | "short", + type: "market" | "limit", + price?: number, + ) { + return await driftPerpTrade(this, { action, amount, symbol, type, price }); + } + async updateDriftVault( + vaultAddress: string, + params: { + name: string; + marketName: `${string}-${string}`; + redeemPeriod: number; + maxTokens: number; + minDepositAmount: number; + managementFee: number; + profitShare: number; + hurdleRate?: number; + permissioned?: boolean; + }, + ) { + return await updateVault(this, vaultAddress, params); + } + async getDriftVaultInfo(vaultName: string) { + return await getVaultInfo(this, vaultName); + } + async withdrawFromDriftAccount( + amount: number, + symbol: string, + isBorrow?: boolean, + ) { + return await withdrawFromDriftUserAccount(this, amount, symbol, isBorrow); + } + async withdrawFromDriftVault(vault: string) { + return await withdrawFromDriftVault(this, vault); + } } diff --git a/src/tools/drift/drift.ts b/src/tools/drift/drift.ts index 525a379..d112fca 100644 --- a/src/tools/drift/drift.ts +++ b/src/tools/drift/drift.ts @@ -276,7 +276,7 @@ export async function driftPerpTrade( symbol: string; action: "long" | "short"; type: "market" | "limit"; - price?: number; + price?: number | undefined; }, ) { try { diff --git a/src/tools/drift/drift_vault.ts b/src/tools/drift/drift_vault.ts index 85e803d..3bfa182 100644 --- a/src/tools/drift/drift_vault.ts +++ b/src/tools/drift/drift_vault.ts @@ -18,6 +18,7 @@ import { import { WithdrawUnit, encodeName, + getVaultAddressSync, getVaultDepositorAddressSync, } from "@drift-labs/vaults-sdk"; import { @@ -233,18 +234,39 @@ export async function updateVault( } } -export async function getVaultInfo( - agent: SolanaAgentKit, - vaultAddress: string, -) { +export async function getVaultInfo(agent: SolanaAgentKit, vaultName: string) { try { const { vaultClient, cleanUp } = await initClients(agent); - const vaultPublicKey = new PublicKey(vaultAddress); + const vaultPublicKey = getVaultAddressSync( + vaultClient.program.programId, + encodeName(vaultName), + ); const vaultDetails = await vaultClient.getVault(vaultPublicKey); await cleanUp(); - return vaultDetails; + const spotToken = MainnetSpotMarkets[vaultDetails.spotMarketIndex]; + const data = { + name: vaultName, + address: vaultPublicKey.toBase58(), + marketName: `${spotToken.symbol}-SPOT`, + redeemPeriod: vaultDetails.redeemPeriod.toNumber(), + maxTokens: vaultDetails.maxTokens.div(spotToken.precision).toNumber(), + minDepositAmount: vaultDetails.minDepositAmount + .div(spotToken.precision) + .toNumber(), + managementFee: + (vaultDetails.managementFee.toNumber() / + PERCENTAGE_PRECISION.toNumber()) * + 100, + profitShare: + (vaultDetails.profitShare / PERCENTAGE_PRECISION.toNumber()) * 100, + hurdleRate: + (vaultDetails.hurdleRate / PERCENTAGE_PRECISION.toNumber()) * 100, + permissioned: vaultDetails.permissioned, + }; + + return data; } catch (e) { // @ts-expect-error - error message is a string throw new Error(`Failed to get vault info: ${e.message}`); @@ -395,6 +417,31 @@ async function getIsOwned(agent: SolanaAgentKit, vault: string) { } } +/** + * Get a vaults address using the vault's name + * @param agent + * @param name + */ +export async function getVaultAddress(agent: SolanaAgentKit, name: string) { + const encodedName = encodeName(name); + + try { + const { vaultClient, cleanUp } = await initClients(agent); + const vaultAddress = getVaultAddressSync( + vaultClient.program.programId, + encodedName, + ); + + await cleanUp(); + return vaultAddress; + } catch (e) { + throw new Error( + // @ts-expect-error - error message is a string + `Failed to get vault address: ${e.message}`, + ); + } +} + /** Carry out a trade with a delegated vault @param agent SolanaAgentKit instance @@ -409,7 +456,7 @@ export async function tradeDriftVault( vault: string, amount: number, symbol: string, - action: "buy" | "sell", + action: "long" | "short", type: "market" | "limit", price?: number, ) { @@ -494,7 +541,7 @@ export async function tradeDriftVault( marketType: MarketType.PERP, baseAssetAmount: numberToSafeBN(baseAmount, BASE_PRECISION), direction: - action === "buy" + action === "long" ? PositionDirection.LONG : PositionDirection.SHORT, marketIndex: perpMarketAccount.marketIndex, @@ -512,7 +559,7 @@ export async function tradeDriftVault( marketType: MarketType.PERP, baseAssetAmount: numberToSafeBN(baseAmount, BASE_PRECISION), direction: - action === "buy" + action === "long" ? PositionDirection.LONG : PositionDirection.SHORT, marketIndex: perpMarketAccount.marketIndex,