feat: add drift tools to langchain (#212)

# Pull Request Description

This PR is the langchain implementation of #207 

## Changes Made
This PR adds the following changes:
<!-- List the key changes made in this PR -->
- This PR adds files that implement the drift actions in a way
compatible with langchain
  
## Implementation Details
<!-- Provide technical details about the implementation -->
- Just a quick conversion of the drift actions to langchain tool classes

## Transaction executed by agent and prompt used
<!-- If applicable, provide example usage, transactions, or screenshots
-->
Example transaction: 
<img width="998" alt="Screenshot 2025-01-15 at 17 43 42"
src="https://github.com/user-attachments/assets/25f12c26-0f1a-470a-a566-028a54adf995"
/>
<img width="998" alt="Screenshot 2025-01-15 at 17 43 27"
src="https://github.com/user-attachments/assets/b07c6089-f5fc-4498-9d5a-14c5698c21a9"
/>
<img width="998" alt="Screenshot 2025-01-15 at 17 43 02"
src="https://github.com/user-attachments/assets/69067241-bb22-429b-9021-024c526ec25f"
/>

## Additional Notes
<!-- Any additional information that reviewers should know -->

## Checklist
- [x] I have tested these changes locally
- [ ] I have updated the documentation
- [ ] I have added a transaction link
- [x] I have added the prompt used to test it
This commit is contained in:
aryan
2025-01-15 22:19:05 +05:30
committed by GitHub
17 changed files with 621 additions and 0 deletions

View File

@@ -0,0 +1,38 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";
export class SolanaCreateDriftUserAccountTool extends Tool {
name = "create_drift_user_account";
description = `Create a new user account with a deposit on Drift protocol.
Inputs (JSON string):
- amount: number, amount of the token to deposit (required)
- symbol: string, symbol of the token to deposit (required)`;
constructor(private solanaKit: SolanaAgentKit) {
super();
}
protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const res = await this.solanaKit.createDriftUserAccount(
parsedInput.amount,
parsedInput.symbol,
);
return JSON.stringify({
status: "success",
message: `User account created with ${parsedInput.amount} ${parsedInput.symbol} successfully deposited`,
account: res.account,
signature: res.txSignature,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "CREATE_DRIFT_USER_ACCOUNT_ERROR",
});
}
}
}

View File

@@ -0,0 +1,42 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";
export class SolanaCreateDriftVaultTool extends Tool {
name = "create_drift_vault";
description = `Create a new drift vault delegating the agents address as the owner.
Inputs (JSON string):
- name: string, unique vault name (min 5 chars)
- marketName: string, market name in TOKEN-SPOT format
- redeemPeriod: number, days to wait before funds can be redeemed (min 1)
- maxTokens: number, maximum tokens vault can accommodate (min 100)
- minDepositAmount: number, minimum deposit amount
- managementFee: number, fee percentage for managing funds (max 20)
- profitShare: number, profit sharing percentage (max 90, default 5)
- hurdleRate: number, optional hurdle rate
- permissioned: boolean, whether vault has whitelist`;
constructor(private solanaKit: SolanaAgentKit) {
super();
}
protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const tx = await this.solanaKit.createDriftVault(parsedInput);
return JSON.stringify({
status: "success",
message: "Drift vault created successfully",
vaultName: parsedInput.name,
signature: tx,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "CREATE_DRIFT_VAULT_ERROR",
});
}
}
}

View File

@@ -0,0 +1,37 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";
export class SolanaDepositIntoDriftVaultTool extends Tool {
name = "deposit_into_drift_vault";
description = `Deposit funds into an existing drift vault.
Inputs (JSON string):
- vaultAddress: string, address of the vault (required)
- amount: number, amount to deposit (required)`;
constructor(private solanaKit: SolanaAgentKit) {
super();
}
protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const tx = await this.solanaKit.depositIntoDriftVault(
parsedInput.amount,
parsedInput.vaultAddress,
);
return JSON.stringify({
status: "success",
message: "Funds deposited successfully",
signature: tx,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "DEPOSIT_INTO_VAULT_ERROR",
});
}
}
}

View File

@@ -0,0 +1,39 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";
export class SolanaDepositToDriftUserAccountTool extends Tool {
name = "deposit_to_drift_user_account";
description = `Deposit funds into your drift user account.
Inputs (JSON string):
- amount: number, amount to deposit (required)
- symbol: string, token symbol (required)
- repay: boolean, whether to repay borrowed funds (optional, default: false)`;
constructor(private solanaKit: SolanaAgentKit) {
super();
}
protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const tx = await this.solanaKit.depositToDriftUserAccount(
parsedInput.amount,
parsedInput.symbol,
parsedInput.repay,
);
return JSON.stringify({
status: "success",
message: "Funds deposited successfully",
signature: tx,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "DEPOSIT_TO_DRIFT_ACCOUNT_ERROR",
});
}
}
}

View File

@@ -0,0 +1,32 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";
export class SolanaDeriveVaultAddressTool extends Tool {
name = "derive_drift_vault_address";
description = `Derive a drift vault address from the vault's name.
Inputs (JSON string):
- name: string, name of the vault to derive the address of (required)`;
constructor(private solanaKit: SolanaAgentKit) {
super();
}
protected async _call(input: string): Promise<string> {
try {
const address = await this.solanaKit.deriveDriftVaultAddress(input);
return JSON.stringify({
status: "success",
message: "Vault address derived successfully",
address,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "DERIVE_VAULT_ADDRESS_ERROR",
});
}
}
}

View File

@@ -0,0 +1,38 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";
export class SolanaCheckDriftAccountTool extends Tool {
name = "does_user_have_drift_account";
description = `Check if a user has a Drift account.
Inputs: No inputs required - checks the current user's account`;
constructor(private solanaKit: SolanaAgentKit) {
super();
}
protected async _call(_input: string): Promise<string> {
try {
const res = await this.solanaKit.doesUserHaveDriftAccount();
if (!res.hasAccount) {
return JSON.stringify({
status: "error",
message: "You do not have a Drift account",
});
}
return JSON.stringify({
status: "success",
message: "Nice! You have a Drift account",
account: res.account,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "CHECK_DRIFT_ACCOUNT_ERROR",
});
}
}
}

View File

@@ -0,0 +1,29 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";
export class SolanaDriftUserAccountInfoTool extends Tool {
name = "drift_user_account_info";
description = `Get information about your drift account.
Inputs: No inputs required - retrieves current user's account info`;
constructor(private solanaKit: SolanaAgentKit) {
super();
}
protected async _call(_input: string): Promise<string> {
try {
const accountInfo = await this.solanaKit.driftUserAccountInfo();
return JSON.stringify({
status: "success",
data: accountInfo,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "DRIFT_ACCOUNT_INFO_ERROR",
});
}
}
}

View File

@@ -0,0 +1,15 @@
export * from "./create_user_account";
export * from "./create_vault";
export * from "./deposit_into_vault";
export * from "./deposit_to_user_account";
export * from "./derive_vault_address";
export * from "./does_user_have_drift_account";
export * from "./drift_user_account_info";
export * from "./request_withdrawal";
export * from "./trade_delegated_vault";
export * from "./trade_perp_account";
export * from "./update_drift_vault_delegate";
export * from "./update_vault";
export * from "./vault_info";
export * from "./withdraw_from_account";
export * from "./withdraw_from_vault";

View File

@@ -0,0 +1,37 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";
export class SolanaRequestDriftWithdrawalTool extends Tool {
name = "request_withdrawal_from_drift_vault";
description = `Request a withdrawal from an existing drift vault.
Inputs (JSON string):
- vaultAddress: string, vault address (required)
- amount: number, amount of shares to withdraw (required)`;
constructor(private solanaKit: SolanaAgentKit) {
super();
}
protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const tx = await this.solanaKit.requestWithdrawalFromDriftVault(
parsedInput.amount,
parsedInput.vaultAddress,
);
return JSON.stringify({
status: "success",
message: "Withdrawal request successful",
signature: tx,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "REQUEST_DRIFT_WITHDRAWAL_ERROR",
});
}
}
}

View File

@@ -0,0 +1,49 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";
export class SolanaTradeDelegatedDriftVaultTool extends Tool {
name = "trade_delegated_drift_vault";
description = `Carry out trades in a Drift vault.
Inputs (JSON string):
- vaultAddress: string, address of the Drift vault
- amount: number, amount to trade
- symbol: string, symbol of the token to trade
- action: "long" | "short", trade direction
- type: "market" | "limit", order type
- price: number, optional limit price`;
constructor(private solanaKit: SolanaAgentKit) {
super();
}
protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const tx = await this.solanaKit.tradeUsingDelegatedDriftVault(
parsedInput.vaultAddress,
parsedInput.amount,
parsedInput.symbol,
parsedInput.action,
parsedInput.type,
parsedInput.price,
);
return JSON.stringify({
status: "success",
message:
parsedInput.type === "limit"
? "Order placed successfully"
: "Trade successful",
transactionId: tx,
...parsedInput,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "TRADE_DRIFT_VAULT_ERROR",
});
}
}
}

View File

@@ -0,0 +1,42 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";
export class SolanaTradeDriftPerpAccountTool extends Tool {
name = "trade_drift_perp_account";
description = `Trade a perpetual account on Drift protocol.
Inputs (JSON string):
- amount: number, amount to trade (required)
- symbol: string, token symbol (required)
- action: "long" | "short", trade direction (required)
- type: "market" | "limit", order type (required)
- price: number, required for limit orders`;
constructor(private solanaKit: SolanaAgentKit) {
super();
}
protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const signature = await this.solanaKit.tradeUsingDriftPerpAccount(
parsedInput.amount,
parsedInput.symbol,
parsedInput.action,
parsedInput.type,
parsedInput.price,
);
return JSON.stringify({
status: "success",
signature,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "TRADE_PERP_ACCOUNT_ERROR",
});
}
}
}

View File

@@ -0,0 +1,37 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";
export class SolanaUpdateDriftVaultDelegateTool extends Tool {
name = "update_drift_vault_delegate";
description = `Update the delegate of a drift vault.
Inputs (JSON string):
- vaultAddress: string, address of the vault (required)
- newDelegate: string, address of the new delegate (required)`;
constructor(private solanaKit: SolanaAgentKit) {
super();
}
protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const tx = await this.solanaKit.updateDriftVaultDelegate(
parsedInput.vaultAddress,
parsedInput.newDelegate,
);
return JSON.stringify({
status: "success",
message: "Vault delegate updated successfully",
signature: tx,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "UPDATE_DRIFT_VAULT_DELEGATE_ERROR",
});
}
}
}

View File

@@ -0,0 +1,52 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";
export class SolanaUpdateDriftVaultTool extends Tool {
name = "update_drift_vault";
description = `Update an existing drift vault with new settings.
Inputs (JSON string):
- vaultAddress: string, vault address (required)
- redeemPeriod: number, days until redemption (optional)
- maxTokens: number, maximum tokens allowed (optional)
- minDepositAmount: number, minimum deposit amount (optional)
- managementFee: number, management fee percentage (optional)
- profitShare: number, profit sharing percentage (optional)
- hurdleRate: number, hurdle rate (optional)
- permissioned: boolean, whitelist requirement (optional)`;
constructor(private solanaKit: SolanaAgentKit) {
super();
}
protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const tx = await this.solanaKit.updateDriftVault(
parsedInput.vaultAddress,
// @ts-expect-error - type mismatch
{
hurdleRate: parsedInput.hurdleRate,
maxTokens: parsedInput.maxTokens,
minDepositAmount: parsedInput.minDepositAmount,
profitShare: parsedInput.profitShare,
managementFee: parsedInput.managementFee,
permissioned: parsedInput.permissioned,
redeemPeriod: parsedInput.redeemPeriod,
},
);
return JSON.stringify({
status: "success",
message: "Drift vault parameters updated successfully",
signature: tx,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "UPDATE_DRIFT_VAULT_ERROR",
});
}
}
}

View File

@@ -0,0 +1,32 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";
export class SolanaDriftVaultInfoTool extends Tool {
name = "drift_vault_info";
description = `Get information about a drift vault.
Inputs (JSON string):
- vaultNameOrAddress: string, name or address of the vault (required)`;
constructor(private solanaKit: SolanaAgentKit) {
super();
}
protected async _call(input: string): Promise<string> {
try {
const vaultInfo = await this.solanaKit.getDriftVaultInfo(input);
return JSON.stringify({
status: "success",
message: "Vault info retrieved successfully",
data: vaultInfo,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "DRIFT_VAULT_INFO_ERROR",
});
}
}
}

View File

@@ -0,0 +1,39 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";
export class SolanaWithdrawFromDriftAccountTool extends Tool {
name = "withdraw_from_drift_account";
description = `Withdraw or borrow funds from your drift account.
Inputs (JSON string):
- amount: number, amount to withdraw (required)
- symbol: string, token symbol (required)
- isBorrow: boolean, whether to borrow funds instead of withdrawing (optional, default: false)`;
constructor(private solanaKit: SolanaAgentKit) {
super();
}
protected async _call(input: string): Promise<string> {
try {
const parsedInput = JSON.parse(input);
const tx = await this.solanaKit.withdrawFromDriftAccount(
parsedInput.amount,
parsedInput.symbol,
parsedInput.isBorrow,
);
return JSON.stringify({
status: "success",
message: "Funds withdrawn successfully",
signature: tx,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "WITHDRAW_FROM_DRIFT_ACCOUNT_ERROR",
});
}
}
}

View File

@@ -0,0 +1,32 @@
import { Tool } from "langchain/tools";
import { SolanaAgentKit } from "../../agent";
export class SolanaWithdrawFromDriftVaultTool extends Tool {
name = "withdraw_from_drift_vault";
description = `Withdraw funds from a vault given the redemption time has elapsed.
Inputs (JSON string):
- vaultAddress: string, vault address (required)`;
constructor(private solanaKit: SolanaAgentKit) {
super();
}
protected async _call(input: string): Promise<string> {
try {
const tx = await this.solanaKit.withdrawFromDriftVault(input);
return JSON.stringify({
status: "success",
message: "Withdrawal successful",
signature: tx,
});
} catch (error: any) {
return JSON.stringify({
status: "error",
message: error.message,
code: error.code || "WITHDRAW_FROM_DRIFT_VAULT_ERROR",
});
}
}
}

View File

@@ -25,6 +25,7 @@ export * from "./sns";
export * from "./lightprotocol";
export * from "./squads";
export * from "./helius";
export * from "./drift";
import { SolanaAgentKit } from "../agent";
import {
@@ -98,6 +99,21 @@ import {
SolanaDeleteHeliusWebhookTool,
SolanaParseTransactionHeliusTool,
SolanaGetAllAssetsByOwner,
SolanaCheckDriftAccountTool,
SolanaCreateDriftUserAccountTool,
SolanaCreateDriftVaultTool,
SolanaDepositIntoDriftVaultTool,
SolanaDepositToDriftUserAccountTool,
SolanaDeriveVaultAddressTool,
SolanaDriftUserAccountInfoTool,
SolanaDriftVaultInfoTool,
SolanaRequestDriftWithdrawalTool,
SolanaTradeDelegatedDriftVaultTool,
SolanaTradeDriftPerpAccountTool,
SolanaUpdateDriftVaultDelegateTool,
SolanaUpdateDriftVaultTool,
SolanaWithdrawFromDriftAccountTool,
SolanaWithdrawFromDriftVaultTool,
} from "./index";
export function createSolanaTools(solanaKit: SolanaAgentKit) {
@@ -177,5 +193,20 @@ export function createSolanaTools(solanaKit: SolanaAgentKit) {
new SolanaHeliusWebhookTool(solanaKit),
new SolanaGetHeliusWebhookTool(solanaKit),
new SolanaDeleteHeliusWebhookTool(solanaKit),
new SolanaCreateDriftUserAccountTool(solanaKit),
new SolanaCreateDriftVaultTool(solanaKit),
new SolanaDepositIntoDriftVaultTool(solanaKit),
new SolanaDepositToDriftUserAccountTool(solanaKit),
new SolanaDeriveVaultAddressTool(solanaKit),
new SolanaCheckDriftAccountTool(solanaKit),
new SolanaDriftUserAccountInfoTool(solanaKit),
new SolanaRequestDriftWithdrawalTool(solanaKit),
new SolanaTradeDelegatedDriftVaultTool(solanaKit),
new SolanaTradeDriftPerpAccountTool(solanaKit),
new SolanaUpdateDriftVaultDelegateTool(solanaKit),
new SolanaUpdateDriftVaultTool(solanaKit),
new SolanaDriftVaultInfoTool(solanaKit),
new SolanaWithdrawFromDriftAccountTool(solanaKit),
new SolanaWithdrawFromDriftVaultTool(solanaKit),
];
}