feat: add drift functions to agent class

This commit is contained in:
michaelessiet
2025-01-14 12:43:28 +01:00
parent 69cfcd90d2
commit f0d84f6924
8 changed files with 223 additions and 36 deletions

View File

@@ -90,7 +90,9 @@ const createDriftVaultAction: Action = {
return {
status: "success",
message: "Drift vault created successfully",
message:
"Drift vault created successfully. Please note down the name of your vault as it is unique and was used to derive your vault address",
vaultName: input.name,
signature: tx,
};
} catch (e) {

View File

@@ -0,0 +1,46 @@
import { z } from "zod";
import type { Action } from "../../types";
import { getVaultAddress } from "../../tools";
const deriveDriftVaultAddressAction: Action = {
name: "DERIVE_DRIFT_VAULT_ADDRESS_ACTION",
similes: ["derive drift vault address", "get drift vault address"],
description: "Derive a drift vault address from the vaults name",
examples: [
[
{
input: {
name: "My Drift Vault",
},
output: {
status: "success",
message: "Vault address derived successfully",
address: "2nFeP7taii3wGVgrWk4YiLMPmhtu3Zg9iXCUu4zGBD",
},
explanation: "Derive a drift vault address",
},
],
],
schema: z.object({
name: z.string().describe("The name of the vault to derive the address of"),
}),
handler: async (agent, input) => {
try {
const address = await getVaultAddress(agent, input.name as string);
return {
status: "success",
message: "Vault address derived successfully",
address,
};
} catch (e) {
return {
status: "error",
// @ts-expect-error - error message
message: `Failed to derive vault address: ${e.message}`,
};
}
},
};
export default deriveDriftVaultAddressAction;

View File

@@ -66,7 +66,7 @@ const tradeDelegatedDriftVaultAction: Action = {
vaultAddress: z.string().describe("Address of the Drift vault to trade in"),
amount: z.number().positive().describe("Amount to trade"),
symbol: z.string().describe("Symbol of the token to trade"),
action: z.enum(["buy", "sell"]).describe("Trade action - buy or sell"),
action: z.enum(["long", "short"]).describe("Trade action - long or short"),
type: z.enum(["market", "limit"]).describe("Trade type - market or limit"),
price: z.number().positive().optional().describe("Price for limit order"),
}),
@@ -76,7 +76,7 @@ const tradeDelegatedDriftVaultAction: Action = {
vaultAddress: input.vaultAddress as string,
amount: input.amount as number,
symbol: input.symbol as string,
action: input.action as "buy" | "sell",
action: input.action as "long" | "short",
type: input.type as "market" | "limit",
price: input.price as number | undefined,
};

View File

@@ -13,7 +13,7 @@ const vaultInfoAction: Action = {
[
{
input: {
vaultAddress: "2nFeP7taii",
vaultName: "test-vault",
},
output: {
status: "success",
@@ -35,35 +35,16 @@ const vaultInfoAction: Action = {
],
],
schema: z.object({
vaultAddress: z.string(),
vaultName: z.string(),
}),
handler: async (agent: SolanaAgentKit, input) => {
try {
const vaultInfo = await getVaultInfo(agent, input.vaultAddress as string);
const spotToken = MainnetSpotMarkets[vaultInfo.spotMarketIndex];
const data = {
name: decodeName(vaultInfo.name),
marketName: `${spotToken.symbol}-SPOT`,
redeemPeriod: vaultInfo.redeemPeriod.toNumber(),
maxTokens: vaultInfo.maxTokens.div(spotToken.precision).toNumber(),
minDepositAmount: vaultInfo.minDepositAmount
.div(spotToken.precision)
.toNumber(),
managementFee:
(vaultInfo.managementFee.toNumber() /
PERCENTAGE_PRECISION.toNumber()) *
100,
profitShare:
(vaultInfo.profitShare / PERCENTAGE_PRECISION.toNumber()) * 100,
hurdleRate:
(vaultInfo.hurdleRate / PERCENTAGE_PRECISION.toNumber()) * 100,
permissioned: vaultInfo.permissioned,
};
const vaultInfo = await getVaultInfo(agent, input.vaultName as string);
return {
status: "success",
message: "Vault info retrieved successfully",
data,
data: vaultInfo,
};
} catch (e) {
return {

View File

@@ -56,6 +56,7 @@ import doesUserHaveDriftAccountAction from "./drift/doesUserHaveDriftAccount";
import depositToDriftUserAccountAction from "./drift/depositToDriftUserAccount";
import withdrawFromDriftAccountAction from "./drift/withdrawFromDriftAccount";
import driftUserAccountInfoAction from "./drift/driftUserAccountInfo";
import deriveDriftVaultAddressAction from "./drift/deriveVaultAddress";
export const ACTIONS = {
WALLET_ADDRESS_ACTION: getWalletAddressAction,
@@ -117,6 +118,7 @@ export const ACTIONS = {
DEPOSIT_TO_DRIFT_USER_ACCOUNT_ACTION: depositToDriftUserAccountAction,
WITHDRAW_OR_BORROW_FROM_DRIFT_ACCOUNT_ACTION: withdrawFromDriftAccountAction,
DRIFT_USER_ACCOUNT_INFO_ACTION: driftUserAccountInfoAction,
DERIVE_DRIFT_VAULT_ADDRESS_ACTION: deriveDriftVaultAddressAction,
};
export type { Action, ActionExample, Handler } from "../types/action";

View File

@@ -82,6 +82,20 @@ import {
getHeliusWebhook,
create_HeliusWebhook,
deleteHeliusWebhook,
createDriftUserAccount,
createVault,
depositIntoVault,
depositToDriftUserAccount,
getVaultAddress,
doesUserHaveDriftAccount,
driftUserAccountInfo,
requestWithdrawalFromVault,
tradeDriftVault,
driftPerpTrade,
updateVault,
getVaultInfo,
withdrawFromDriftUserAccount,
withdrawFromDriftVault,
} from "../tools";
import {
Config,
@@ -694,4 +708,99 @@ export class SolanaAgentKit {
async deleteWebhook(webhookID: string): Promise<any> {
return deleteHeliusWebhook(this, webhookID);
}
async createDriftUserAccount(depositAmount: number, depositSymbol: string) {
return await createDriftUserAccount(this, depositAmount, depositSymbol);
}
async createDriftVault(params: {
name: string;
marketName: `${string}-${string}`;
redeemPeriod: number;
maxTokens: number;
minDepositAmount: number;
managementFee: number;
profitShare: number;
hurdleRate?: number;
permissioned?: boolean;
}) {
return await createVault(this, params);
}
async depositIntoDriftVault(amount: number, vault: string) {
return await depositIntoVault(this, amount, vault);
}
async depositToDriftUserAccount(
amount: number,
symbol: string,
isRepayment?: boolean,
) {
return await depositToDriftUserAccount(this, amount, symbol, isRepayment);
}
async deriveDriftVaultAddress(name: string) {
return await getVaultAddress(this, name);
}
async doesUserHaveDriftAccount() {
return await doesUserHaveDriftAccount(this);
}
async driftUserAccountInfo() {
return await driftUserAccountInfo(this);
}
async requestWithdrawalFromDriftVault(amount: number, vault: string) {
return await requestWithdrawalFromVault(this, amount, vault);
}
async tradeUsingDelegatedDriftVault(
vault: string,
amount: number,
symbol: string,
action: "long" | "short",
type: "market" | "limit",
price?: number,
) {
return await tradeDriftVault(
this,
vault,
amount,
symbol,
action,
type,
price,
);
}
async tradeUsingDriftPerpAccount(
amount: number,
symbol: string,
action: "long" | "short",
type: "market" | "limit",
price?: number,
) {
return await driftPerpTrade(this, { action, amount, symbol, type, price });
}
async updateDriftVault(
vaultAddress: string,
params: {
name: string;
marketName: `${string}-${string}`;
redeemPeriod: number;
maxTokens: number;
minDepositAmount: number;
managementFee: number;
profitShare: number;
hurdleRate?: number;
permissioned?: boolean;
},
) {
return await updateVault(this, vaultAddress, params);
}
async getDriftVaultInfo(vaultName: string) {
return await getVaultInfo(this, vaultName);
}
async withdrawFromDriftAccount(
amount: number,
symbol: string,
isBorrow?: boolean,
) {
return await withdrawFromDriftUserAccount(this, amount, symbol, isBorrow);
}
async withdrawFromDriftVault(vault: string) {
return await withdrawFromDriftVault(this, vault);
}
}

View File

@@ -276,7 +276,7 @@ export async function driftPerpTrade(
symbol: string;
action: "long" | "short";
type: "market" | "limit";
price?: number;
price?: number | undefined;
},
) {
try {

View File

@@ -18,6 +18,7 @@ import {
import {
WithdrawUnit,
encodeName,
getVaultAddressSync,
getVaultDepositorAddressSync,
} from "@drift-labs/vaults-sdk";
import {
@@ -233,18 +234,39 @@ export async function updateVault(
}
}
export async function getVaultInfo(
agent: SolanaAgentKit,
vaultAddress: string,
) {
export async function getVaultInfo(agent: SolanaAgentKit, vaultName: string) {
try {
const { vaultClient, cleanUp } = await initClients(agent);
const vaultPublicKey = new PublicKey(vaultAddress);
const vaultPublicKey = getVaultAddressSync(
vaultClient.program.programId,
encodeName(vaultName),
);
const vaultDetails = await vaultClient.getVault(vaultPublicKey);
await cleanUp();
return vaultDetails;
const spotToken = MainnetSpotMarkets[vaultDetails.spotMarketIndex];
const data = {
name: vaultName,
address: vaultPublicKey.toBase58(),
marketName: `${spotToken.symbol}-SPOT`,
redeemPeriod: vaultDetails.redeemPeriod.toNumber(),
maxTokens: vaultDetails.maxTokens.div(spotToken.precision).toNumber(),
minDepositAmount: vaultDetails.minDepositAmount
.div(spotToken.precision)
.toNumber(),
managementFee:
(vaultDetails.managementFee.toNumber() /
PERCENTAGE_PRECISION.toNumber()) *
100,
profitShare:
(vaultDetails.profitShare / PERCENTAGE_PRECISION.toNumber()) * 100,
hurdleRate:
(vaultDetails.hurdleRate / PERCENTAGE_PRECISION.toNumber()) * 100,
permissioned: vaultDetails.permissioned,
};
return data;
} catch (e) {
// @ts-expect-error - error message is a string
throw new Error(`Failed to get vault info: ${e.message}`);
@@ -395,6 +417,31 @@ async function getIsOwned(agent: SolanaAgentKit, vault: string) {
}
}
/**
* Get a vaults address using the vault's name
* @param agent
* @param name
*/
export async function getVaultAddress(agent: SolanaAgentKit, name: string) {
const encodedName = encodeName(name);
try {
const { vaultClient, cleanUp } = await initClients(agent);
const vaultAddress = getVaultAddressSync(
vaultClient.program.programId,
encodedName,
);
await cleanUp();
return vaultAddress;
} catch (e) {
throw new Error(
// @ts-expect-error - error message is a string
`Failed to get vault address: ${e.message}`,
);
}
}
/**
Carry out a trade with a delegated vault
@param agent SolanaAgentKit instance
@@ -409,7 +456,7 @@ export async function tradeDriftVault(
vault: string,
amount: number,
symbol: string,
action: "buy" | "sell",
action: "long" | "short",
type: "market" | "limit",
price?: number,
) {
@@ -494,7 +541,7 @@ export async function tradeDriftVault(
marketType: MarketType.PERP,
baseAssetAmount: numberToSafeBN(baseAmount, BASE_PRECISION),
direction:
action === "buy"
action === "long"
? PositionDirection.LONG
: PositionDirection.SHORT,
marketIndex: perpMarketAccount.marketIndex,
@@ -512,7 +559,7 @@ export async function tradeDriftVault(
marketType: MarketType.PERP,
baseAssetAmount: numberToSafeBN(baseAmount, BASE_PRECISION),
direction:
action === "buy"
action === "long"
? PositionDirection.LONG
: PositionDirection.SHORT,
marketIndex: perpMarketAccount.marketIndex,