feat: drift perp trade, account creation, withdraw, and deposit

This commit is contained in:
michaelessiet
2025-01-12 01:33:11 +01:00
parent 8bef717eb9
commit 25f0f503cb
10 changed files with 738 additions and 77 deletions

344
src/tools/drift.ts Normal file
View File

@@ -0,0 +1,344 @@
import {
BASE_PRECISION,
convertToNumber,
DRIFT_PROGRAM_ID,
DriftClient,
FastSingleTxSender,
getLimitOrderParams,
getMarketOrderParams,
getUserAccountPublicKeySync,
MainnetSpotMarkets,
numberToSafeBN,
PositionDirection,
PostOnlyParams,
PRICE_PRECISION,
User,
type IWallet,
} from "@drift-labs/sdk";
import type { SolanaAgentKit } from "../agent";
import * as anchor from "@coral-xyz/anchor";
import { IDL, VAULT_PROGRAM_ID, VaultClient } from "@drift-labs/vaults-sdk";
import { BN } from "bn.js";
import { getAssociatedTokenAddressSync } from "@solana/spl-token";
import { PublicKey } from "@solana/web3.js";
export async function initClients(agent: SolanaAgentKit) {
const wallet: IWallet = {
publicKey: agent.wallet.publicKey,
payer: agent.wallet,
signAllTransactions: async (txs) => {
for (const tx of txs) {
tx.sign(agent.wallet);
}
return txs;
},
signTransaction: async (tx) => {
tx.sign(agent.wallet);
return tx;
},
};
const driftClient = new DriftClient({
connection: agent.connection,
wallet,
env: "mainnet-beta",
txSender: new FastSingleTxSender({
connection: agent.connection,
wallet,
timeout: 30000,
blockhashRefreshInterval: 1000,
opts: {
commitment: agent.connection.commitment ?? "confirmed",
skipPreflight: false,
preflightCommitment: agent.connection.commitment ?? "confirmed",
},
}),
});
const vaultProgram = new anchor.Program(
IDL,
VAULT_PROGRAM_ID,
driftClient.provider,
);
const vaultClient = new VaultClient({
driftClient,
// @ts-expect-error - type mismatch due to different dep versions
program: vaultProgram,
cliMode: false,
});
await driftClient.subscribe();
async function cleanUp() {
await driftClient.unsubscribe();
}
return { driftClient, vaultClient, cleanUp };
}
/**
* Create a drift user account provided an amount
* @param amount amount of the token to deposit
* @param symbol symbol of the token to deposit
*/
export async function createDriftUserAccount(
agent: SolanaAgentKit,
amount: number,
symbol: string,
) {
try {
const { driftClient, cleanUp } = await initClients(agent);
const user = new User({
driftClient,
userAccountPublicKey: getUserAccountPublicKeySync(
new PublicKey(DRIFT_PROGRAM_ID),
agent.wallet.publicKey,
),
});
const userAccountExists = await user.exists();
const token = MainnetSpotMarkets.find(
(v) => v.symbol === symbol.toUpperCase(),
);
if (!token) {
throw new Error(`Token with symbol ${symbol} not found`);
}
if (!userAccountExists) {
const depositAmount = new BN(amount).mul(token.precision);
const [txSignature, account] =
await driftClient.initializeUserAccountAndDepositCollateral(
depositAmount,
getAssociatedTokenAddressSync(token.mint, agent.wallet.publicKey),
);
await cleanUp();
return { txSignature, account };
}
await cleanUp();
return {
message: "User account already exists",
account: user.userAccountPublicKey,
};
} catch (e) {
// @ts-expect-error - error message is a string
throw new Error(`Failed to create user account: ${e.message}`);
}
}
/**
* Deposit to your drift user account
* @param agent
* @param amount
* @param symbol
* @param address
* @returns
*/
export async function depositToDriftUserAccount(
agent: SolanaAgentKit,
amount: number,
symbol: string,
address?: string,
) {
try {
const { driftClient, cleanUp } = await initClients(agent);
const publicKey = address ? new PublicKey(address) : agent.wallet.publicKey;
const user = new User({
driftClient,
userAccountPublicKey: getUserAccountPublicKeySync(
new PublicKey(DRIFT_PROGRAM_ID),
publicKey,
),
});
const userAccountExists = await user.exists();
const token = MainnetSpotMarkets.find(
(v) => v.symbol === symbol.toUpperCase(),
);
if (!token) {
throw new Error(`Token with symbol ${symbol} not found`);
}
if (!userAccountExists) {
throw new Error("You need to create a Drift user account first.");
}
const depositAmount = new BN(amount).mul(token.precision);
const txSignature = await driftClient.deposit(
depositAmount,
token.marketIndex,
getAssociatedTokenAddressSync(token.mint, publicKey),
);
await cleanUp();
return txSignature;
} catch (e) {
// @ts-expect-error - error message is a string
throw new Error(`Failed to deposit to user account: ${e.message}`);
}
}
export async function withdrawFromDriftUserAccount(
agent: SolanaAgentKit,
amount: number,
symbol: string,
) {
try {
const { driftClient, cleanUp } = await initClients(agent);
const user = new User({
driftClient,
userAccountPublicKey: getUserAccountPublicKeySync(
new PublicKey(DRIFT_PROGRAM_ID),
agent.wallet.publicKey,
),
});
const userAccountExists = await user.exists();
if (!userAccountExists) {
throw new Error("You need to create a Drift user account first.");
}
const token = MainnetSpotMarkets.find(
(v) => v.symbol === symbol.toUpperCase(),
);
if (!token) {
throw new Error(`Token with symbol ${symbol} not found`);
}
const withdrawAmount = numberToSafeBN(amount, token.precision);
const txSignature = await driftClient.withdraw(
withdrawAmount,
token.marketIndex,
getAssociatedTokenAddressSync(token.mint, agent.wallet.publicKey),
);
await cleanUp();
return txSignature;
} catch (e) {
// @ts-expect-error - error message is a string
throw new Error(`Failed to withdraw from user account: ${e.message}`);
}
}
/**
* Open a perpetual trade on drift
* @param agent
* @param params.amount
* @param params.symbol
* @param params.action
* @param params.type
* @param params.price this should only be supplied if type is limit
* @param params.reduceOnly
*/
export async function driftPerpTrade(
agent: SolanaAgentKit,
params: {
amount: number;
symbol: string;
action: "long" | "short";
type: "market" | "limit";
price?: number;
},
) {
try {
const { driftClient, cleanUp } = await initClients(agent);
const user = new User({
driftClient,
userAccountPublicKey: getUserAccountPublicKeySync(
new PublicKey(DRIFT_PROGRAM_ID),
agent.wallet.publicKey,
),
});
const userAccountExists = await user.exists();
if (!userAccountExists) {
throw new Error("You need to create a Drift user account first.");
}
const market = driftClient.getMarketIndexAndType(
`${params.symbol.toUpperCase()}-PERP`,
);
if (!market) {
throw new Error(`Token with symbol ${params.symbol} not found`);
}
const baseAssetPrice = driftClient.getOracleDataForPerpMarket(
market.marketIndex,
);
const convertedAmount =
params.amount / convertToNumber(baseAssetPrice.price, PRICE_PRECISION);
let signature: anchor.web3.TransactionSignature;
if (params.type === "limit") {
if (!params.price) {
throw new Error("Price is required for limit orders");
}
signature = await driftClient.placePerpOrder(
getLimitOrderParams({
baseAssetAmount: numberToSafeBN(convertedAmount, BASE_PRECISION),
reduceOnly: false,
direction:
params.action === "long"
? PositionDirection.LONG
: PositionDirection.SHORT,
marketIndex: market.marketIndex,
price: numberToSafeBN(params.price, PRICE_PRECISION),
postOnly: PostOnlyParams.SLIDE,
}),
);
} else {
signature = await driftClient.placePerpOrder(
getMarketOrderParams({
baseAssetAmount: numberToSafeBN(convertedAmount, BASE_PRECISION),
reduceOnly: false,
direction:
params.action === "long"
? PositionDirection.LONG
: PositionDirection.SHORT,
marketIndex: market.marketIndex,
}),
);
}
if (!signature) {
throw new Error("Failed to place order. Please make sure ");
}
await cleanUp();
return signature;
} catch (e) {
// @ts-expect-error - error message is a string
throw new Error(`Failed to place order: ${e.message}`);
}
}
/**
* Check if a user has a drift account
* @param agent
*/
export async function doesUserHaveDriftAccount(agent: SolanaAgentKit) {
try {
const { driftClient, cleanUp } = await initClients(agent);
const user = new User({
driftClient,
userAccountPublicKey: getUserAccountPublicKeySync(
new PublicKey(DRIFT_PROGRAM_ID),
agent.wallet.publicKey,
),
});
user.getActivePerpPositions();
const userAccountExists = await user.exists();
await cleanUp();
return {
hasAccount: userAccountExists,
account: user.userAccountPublicKey,
};
} catch (e) {
// @ts-expect-error - error message is a string
throw new Error(`Failed to check user account: ${e.message}`);
}
}

View File

@@ -1,8 +1,6 @@
import {
BASE_PRECISION,
convertToNumber,
DriftClient,
FastSingleTxSender,
getLimitOrderParams,
getMarketOrderParams,
getOrderParams,
@@ -16,17 +14,12 @@ import {
PRICE_PRECISION,
QUOTE_PRECISION,
TEN,
type IWallet,
} from "@drift-labs/sdk";
import {
VAULT_PROGRAM_ID,
VaultClient,
IDL,
WithdrawUnit,
encodeName,
getVaultDepositorAddressSync,
} from "@drift-labs/vaults-sdk";
import * as anchor from "@coral-xyz/anchor";
import {
ComputeBudgetProgram,
PublicKey,
@@ -34,6 +27,7 @@ import {
} from "@solana/web3.js";
import type { SolanaAgentKit } from "../agent";
import { BN } from "bn.js";
import { initClients } from "./drift";
export function getMarketIndexAndType(name: `${string}-${string}`) {
const [symbol, type] = name.toUpperCase().split("-");
@@ -53,58 +47,6 @@ export function getMarketIndexAndType(name: `${string}-${string}`) {
return { marketIndex: token.marketIndex, marketType: MarketType.SPOT };
}
async function initClients(agent: SolanaAgentKit) {
const wallet: IWallet = {
publicKey: agent.wallet.publicKey,
payer: agent.wallet,
signAllTransactions: async (txs) => {
for (const tx of txs) {
tx.sign(agent.wallet);
}
return txs;
},
signTransaction: async (tx) => {
tx.sign(agent.wallet);
return tx;
},
};
const driftClient = new DriftClient({
connection: agent.connection,
wallet,
env: "mainnet-beta",
txSender: new FastSingleTxSender({
connection: agent.connection,
wallet,
timeout: 30000,
blockhashRefreshInterval: 1000,
opts: {
commitment: agent.connection.commitment ?? "confirmed",
skipPreflight: false,
preflightCommitment: agent.connection.commitment ?? "confirmed",
},
}),
});
const vaultProgram = new anchor.Program(
IDL,
VAULT_PROGRAM_ID,
driftClient.provider,
);
const vaultClient = new VaultClient({
driftClient,
// @ts-expect-error - type mismatch due to different dep versions
program: vaultProgram,
cliMode: false,
});
await driftClient.subscribe();
async function cleanUp() {
await driftClient.unsubscribe();
}
return { driftClient, vaultClient, cleanUp };
}
async function getOrCreateVaultDepositor(agent: SolanaAgentKit, vault: string) {
const { vaultClient, cleanUp } = await initClients(agent);
const vaultPublicKey = new PublicKey(vault);
@@ -236,27 +178,49 @@ export async function updateVault(
},
) {
try {
const { vaultClient, cleanUp } = await initClients(agent);
const { vaultClient, cleanUp, driftClient } = await initClients(agent);
const vaultPublicKey = new PublicKey(vault);
const vaultDetails = await vaultClient.getVault(vaultPublicKey);
const spotMarket = driftClient.getSpotMarketAccount(
vaultDetails.spotMarketIndex,
);
if (!spotMarket) {
throw new Error("Market not found");
}
const spotPrecision = TEN.pow(new BN(spotMarket.decimals));
const tx = await vaultClient.managerUpdateVault(vaultPublicKey, {
redeemPeriod: new BN(
params.redeemPeriod
? params.redeemPeriod * 86400
: vaultDetails.redeemPeriod,
),
maxTokens: new BN(params.maxTokens ?? vaultDetails.maxTokens),
minDepositAmount: new BN(
params.minDepositAmount ?? vaultDetails.minDepositAmount,
),
managementFee: new BN(params.managementFee ?? vaultDetails.managementFee),
profitShare: new BN(
params.profitShare ?? vaultDetails.profitShare,
).toNumber(),
hurdleRate: new BN(
params.hurdleRate ?? vaultDetails.hurdleRate,
).toNumber(),
maxTokens: params.maxTokens
? numberToSafeBN(params.maxTokens, spotPrecision)
: vaultDetails.maxTokens,
minDepositAmount: params.minDepositAmount
? numberToSafeBN(params.minDepositAmount, spotPrecision)
: vaultDetails.minDepositAmount,
managementFee: params.managementFee
? new BN(params.managementFee)
.mul(PERCENTAGE_PRECISION)
.div(new BN(100))
: vaultDetails.managementFee,
profitShare: params.profitShare
? new BN(params.profitShare)
.mul(PERCENTAGE_PRECISION)
.div(new BN(100))
.toNumber()
: vaultDetails.profitShare,
hurdleRate: params.hurdleRate
? new BN(params.hurdleRate)
.mul(PERCENTAGE_PRECISION)
.div(new BN(100))
.toNumber()
: vaultDetails.hurdleRate,
permissioned: params.permissioned ?? vaultDetails.permissioned,
});
@@ -494,15 +458,18 @@ export async function tradeDriftVault(
const perpMarketIndexAndType = getMarketIndexAndType(
`${symbol.toUpperCase()}-PERP`,
);
const perpMarketAccount = driftClient.getPerpMarketAccount(
perpMarketIndexAndType.marketIndex,
);
if (!perpMarketIndexAndType) {
if (!perpMarketIndexAndType || !perpMarketAccount) {
throw new Error(
"Invalid symbol: Drift doesn't have a market for this token",
);
}
const perpOracle = driftClient.getOracleDataForPerpMarket(
perpMarketIndexAndType.marketIndex,
perpMarketAccount.marketIndex,
);
const oraclePriceNumber = convertToNumber(
perpOracle.price,
@@ -530,7 +497,7 @@ export async function tradeDriftVault(
action === "buy"
? PositionDirection.LONG
: PositionDirection.SHORT,
marketIndex: perpMarketIndexAndType.marketIndex,
marketIndex: perpMarketAccount.marketIndex,
postOnly: PostOnlyParams.SLIDE,
}),
),
@@ -548,7 +515,7 @@ export async function tradeDriftVault(
action === "buy"
? PositionDirection.LONG
: PositionDirection.SHORT,
marketIndex: perpMarketIndexAndType.marketIndex,
marketIndex: perpMarketAccount.marketIndex,
}),
),
]);

View File

@@ -52,3 +52,4 @@ export * from "./flash_close_trade";
export * from "./create_3land_collectible";
export * from "./drift_vault";
export * from "./drift";