feat + fix: add action to get user drift account info and fix deposit, borrow, lend and withdraw on drift

This commit is contained in:
michaelessiet
2025-01-13 21:17:53 +01:00
parent 25f0f503cb
commit 60adc8d8c5
7 changed files with 183 additions and 36 deletions

View File

@@ -12,15 +12,17 @@ import {
PositionDirection,
PostOnlyParams,
PRICE_PRECISION,
QUOTE_PRECISION,
User,
type IWallet,
} from "@drift-labs/sdk";
import type { SolanaAgentKit } from "../agent";
import * as anchor from "@coral-xyz/anchor";
import { IDL, VAULT_PROGRAM_ID, VaultClient } from "@drift-labs/vaults-sdk";
import { BN } from "bn.js";
import { getAssociatedTokenAddressSync } from "@solana/spl-token";
import { PublicKey } from "@solana/web3.js";
import { Transaction } from "@solana/web3.js";
import { ComputeBudgetProgram } from "@solana/web3.js";
export async function initClients(agent: SolanaAgentKit) {
const wallet: IWallet = {
@@ -103,7 +105,7 @@ export async function createDriftUserAccount(
}
if (!userAccountExists) {
const depositAmount = new BN(amount).mul(token.precision);
const depositAmount = numberToSafeBN(amount, token.precision);
const [txSignature, account] =
await driftClient.initializeUserAccountAndDepositCollateral(
depositAmount,
@@ -130,18 +132,18 @@ export async function createDriftUserAccount(
* @param agent
* @param amount
* @param symbol
* @param address
* @param isRepay
* @returns
*/
export async function depositToDriftUserAccount(
agent: SolanaAgentKit,
amount: number,
symbol: string,
address?: string,
isRepay = false,
) {
try {
const { driftClient, cleanUp } = await initClients(agent);
const publicKey = address ? new PublicKey(address) : agent.wallet.publicKey;
const publicKey = agent.wallet.publicKey;
const user = new User({
driftClient,
userAccountPublicKey: getUserAccountPublicKeySync(
@@ -162,11 +164,29 @@ export async function depositToDriftUserAccount(
throw new Error("You need to create a Drift user account first.");
}
const depositAmount = new BN(amount).mul(token.precision);
const txSignature = await driftClient.deposit(
depositAmount,
token.marketIndex,
getAssociatedTokenAddressSync(token.mint, publicKey),
const depositAmount = numberToSafeBN(amount, token.precision);
const [depInstruction, latestBlockhash] = await Promise.all([
driftClient.getDepositTxnIx(
depositAmount,
token.marketIndex,
getAssociatedTokenAddressSync(token.mint, publicKey),
undefined,
isRepay,
),
driftClient.connection.getLatestBlockhash(),
]);
const tx = new Transaction().add(...depInstruction).add(
ComputeBudgetProgram.setComputeUnitPrice({
microLamports: 0.000001 * 1000000 * 1000000,
}),
);
tx.recentBlockhash = latestBlockhash.blockhash;
tx.sign(agent.wallet);
const txSignature = await driftClient.txSender.sendRawTransaction(
tx.serialize(),
{ ...driftClient.opts },
);
await cleanUp();
@@ -181,6 +201,7 @@ export async function withdrawFromDriftUserAccount(
agent: SolanaAgentKit,
amount: number,
symbol: string,
isBorrow = false,
) {
try {
const { driftClient, cleanUp } = await initClients(agent);
@@ -207,10 +228,27 @@ export async function withdrawFromDriftUserAccount(
const withdrawAmount = numberToSafeBN(amount, token.precision);
const txSignature = await driftClient.withdraw(
withdrawAmount,
token.marketIndex,
getAssociatedTokenAddressSync(token.mint, agent.wallet.publicKey),
const [withdrawInstruction, latestBlockhash] = await Promise.all([
driftClient.getWithdrawalIxs(
withdrawAmount,
token.marketIndex,
getAssociatedTokenAddressSync(token.mint, agent.wallet.publicKey),
!isBorrow,
),
driftClient.connection.getLatestBlockhash(),
]);
const tx = new Transaction().add(...withdrawInstruction).add(
ComputeBudgetProgram.setComputeUnitPrice({
microLamports: 0.000001 * 1000000 * 1000000,
}),
);
tx.recentBlockhash = latestBlockhash.blockhash;
tx.sign(agent.wallet);
const txSignature = await driftClient.txSender.sendRawTransaction(
tx.serialize(),
{ ...driftClient.opts },
);
await cleanUp();
@@ -289,6 +327,9 @@ export async function driftPerpTrade(
price: numberToSafeBN(params.price, PRICE_PRECISION),
postOnly: PostOnlyParams.SLIDE,
}),
{
computeUnitsPrice: 0.000001 * 1000000 * 1000000,
},
);
} else {
signature = await driftClient.placePerpOrder(
@@ -301,6 +342,9 @@ export async function driftPerpTrade(
: PositionDirection.SHORT,
marketIndex: market.marketIndex,
}),
{
computeUnitsPrice: 0.000001 * 1000000 * 1000000,
},
);
}
@@ -342,3 +386,61 @@ export async function doesUserHaveDriftAccount(agent: SolanaAgentKit) {
throw new Error(`Failed to check user account: ${e.message}`);
}
}
/**
* Get account info for a drift User
* @param agent
* @returns
*/
export async function driftUserAccountInfo(agent: SolanaAgentKit) {
try {
const { driftClient, cleanUp } = await initClients(agent);
const user = new User({
driftClient,
userAccountPublicKey: getUserAccountPublicKeySync(
new PublicKey(DRIFT_PROGRAM_ID),
agent.wallet.publicKey,
),
});
const userAccountExists = await user.exists();
if (!userAccountExists) {
throw new Error("User account does not exist");
}
await user.subscribe();
const account = user.getUserAccount();
await user.unsubscribe();
await cleanUp();
const perpPositions = account.perpPositions.map((pos) => ({
...pos,
baseAssetAmount: convertToNumber(pos.baseAssetAmount, BASE_PRECISION),
settledPnl: convertToNumber(pos.settledPnl, QUOTE_PRECISION),
}));
const spotPositions = account.spotPositions.map((pos) => ({
...pos,
scaledBalance: convertToNumber(pos.scaledBalance, BASE_PRECISION),
cumulativeDeposits: convertToNumber(
pos.cumulativeDeposits,
BASE_PRECISION,
),
symbol: MainnetSpotMarkets.find((v) => v.marketIndex === pos.marketIndex)
?.symbol,
}));
return {
...account,
name: account.name,
authority: account.authority,
totalDeposits: `$${convertToNumber(account.totalDeposits, QUOTE_PRECISION)}`,
totalWithdraws: `$${convertToNumber(account.totalWithdraws, QUOTE_PRECISION)}`,
settledPerpPnl: `$${convertToNumber(account.settledPerpPnl, QUOTE_PRECISION)}`,
lastActiveSlot: account.lastActiveSlot.toNumber(),
perpPositions,
spotPositions,
};
} catch (e) {
// @ts-expect-error - error message is a string
throw new Error(`Failed to check user account: ${e.message}`);
}
}

View File

@@ -134,7 +134,7 @@ export async function createVault(
.toNumber(),
minDepositAmount: numberToSafeBN(params.minDepositAmount, spotPrecision),
redeemPeriod: new BN(params.redeemPeriod * 86400),
maxTokens: new BN(params.maxTokens).mul(spotPrecision),
maxTokens: numberToSafeBN(params.maxTokens, spotPrecision),
managementFee: new BN(params.managementFee)
.mul(PERCENTAGE_PRECISION)
.div(new BN(100)),