diff --git a/Cargo.lock b/Cargo.lock index be9d492..fcf77ae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2346,6 +2346,7 @@ dependencies = [ name = "ore" version = "0.1.0" dependencies = [ + "arrayvec", "bincode", "bs64", "bytemuck", diff --git a/Cargo.toml b/Cargo.toml index 16ce2d0..1ebf2b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ no-entrypoint = [] default = [] [dependencies] +arrayvec = { version = "0.7.4", default-features = false } bincode = "1.3.3" bytemuck = "1.14.3" num_enum = "0.7.2" diff --git a/src/entrypoint_nostd.rs b/src/entrypoint_nostd.rs new file mode 100644 index 0000000..9b7649d --- /dev/null +++ b/src/entrypoint_nostd.rs @@ -0,0 +1,555 @@ +extern crate alloc; +use alloc::rc::Rc; + +use core::{cell::RefCell, marker::PhantomData, mem::size_of, ptr::NonNull, slice::from_raw_parts}; + +use arrayvec::ArrayVec; +use bytemuck::{Pod, Zeroable}; +use solana_program::{ + account_info::AccountInfo, + entrypoint::{BPF_ALIGN_OF_U128, MAX_PERMITTED_DATA_INCREASE, NON_DUP_MARKER}, + log, + pubkey::Pubkey, +}; + +#[macro_export] +macro_rules! entrypoint_nostd { + ($process_instruction:ident, $accounts:literal) => { + #[no_mangle] + pub unsafe extern "C" fn entrypoint(input: *mut u8) -> u64 { + let (program_id, accounts, instruction_data) = + unsafe { $crate::deserialize_nostd::<$accounts>(input) }; + match $process_instruction(&program_id, &accounts, &instruction_data) { + Ok(()) => solana_program::entrypoint::SUCCESS, + Err(error) => error.into(), + } + } + }; +} + +#[macro_export] +macro_rules! entrypoint_nostd_no_duplicates { + ($process_instruction:ident, $accounts:literal) => { + #[no_mangle] + pub unsafe extern "C" fn entrypoint(input: *mut u8) -> u64 { + let Some((program_id, accounts, instruction_data)) = + $crate::deserialize_nostd_no_dup::<$accounts>(input) + else { + // TODO: better error + solana_program::log::sol_log("a duplicate account was found"); + return u64::MAX; + }; + // solana_program::entrypoint::SUCCESS + match $process_instruction(&program_id, &accounts, &instruction_data) { + Ok(()) => solana_program::entrypoint::SUCCESS, + Err(error) => error.into(), + } + } + }; +} + +pub unsafe fn deserialize_nostd<'a, const MAX_ACCOUNTS: usize>( + input: *mut u8, +) -> ( + &'a Pubkey, + ArrayVec, + &'a [u8], +) { + let mut offset: usize = 0; + + // Number of accounts present + #[allow(clippy::cast_ptr_alignment)] + let num_accounts = *(input.add(offset) as *const u64) as usize; + offset += size_of::(); + + // Account Infos + let mut accounts = ArrayVec::new(); + for _ in 0..num_accounts { + let dup_info = *(input.add(offset) as *const u8); + if dup_info == NON_DUP_MARKER { + // MAGNETAR FIELDS: safety depends on alignment, size + // 1) we will always be 8 byte aligned due to align_offset + // 2) solana vm serialization format is consistent so size is ok + let account_info: &mut NoStdAccountInfo4Inner = + core::mem::transmute::<&mut u8, _>(&mut *(input.add(offset))); + // bytemuck::try_from_bytes_mut(from_raw_parts_mut(input.add(offset), 88)).unwrap(); + + offset += size_of::(); + offset += account_info.data_len; + offset += MAX_PERMITTED_DATA_INCREASE; + offset += (offset as *const u8).align_offset(BPF_ALIGN_OF_U128); + offset += size_of::(); // MAGNETAR FIELDS: ignore rent epoch + + // MAGNETAR FIELDS: reset borrow state right before pushing + account_info.borrow_state = 0b_0000_0000; + if accounts + .try_push(NoStdAccountInfo4 { + inner: account_info, + }) + .is_err() + { + log::sol_log("ArrayVec is full. Truncating input accounts"); + }; + } else { + offset += 8; + + // Duplicate account, clone the original + if accounts + .try_push(accounts[dup_info as usize].clone()) + .is_err() + { + log::sol_log("ArrayVec is full. Truncating input accounts"); + }; + } + } + + // Instruction data + #[allow(clippy::cast_ptr_alignment)] + let instruction_data_len = *(input.add(offset) as *const u64) as usize; + offset += size_of::(); + + let instruction_data = { from_raw_parts(input.add(offset), instruction_data_len) }; + offset += instruction_data_len; + + // Program Id + let program_id: &Pubkey = &*(input.add(offset) as *const Pubkey); + + (program_id, accounts, instruction_data) +} + +pub unsafe fn deserialize_nostd_no_dup<'a, const MAX_ACCOUNTS: usize>( + input: *mut u8, +) -> Option<( + &'a Pubkey, + ArrayVec, + &'a [u8], +)> { + let mut offset: usize = 0; + + // Number of accounts present + #[allow(clippy::cast_ptr_alignment)] + let num_accounts = *(input.add(offset) as *const u64) as usize; + offset += size_of::(); + + // Account Infos + let mut accounts = ArrayVec::new(); + for _ in 0..num_accounts { + let dup_info = *(input.add(offset) as *const u8); + if dup_info == NON_DUP_MARKER { + // MAGNETAR FIELDS: safety depends on alignment, size + // 1) we will always be 8 byte aligned due to align_offset + // 2) solana vm serialization format is consistent so size is ok + let account_info: &mut NoStdAccountInfo4Inner = + core::mem::transmute::<&mut u8, _>(&mut *(input.add(offset))); + // bytemuck::try_from_bytes_mut(from_raw_parts_mut(input.add(offset), 88)).unwrap(); + offset += size_of::(); + offset += account_info.data_len; + offset += MAX_PERMITTED_DATA_INCREASE; + offset += (offset as *const u8).align_offset(BPF_ALIGN_OF_U128); + offset += size_of::(); // MAGNETAR FIELDS: ignore rent epoch + + // MAGNETAR FIELDS: reset borrow state right before pushing + account_info.borrow_state = 0b_0000_0000; + if accounts + .try_push(NoStdAccountInfo4 { + inner: account_info, + }) + .is_err() + { + log::sol_log("ArrayVec is full. Truncating input accounts"); + }; + } else { + return None; + } + } + + // Instruction data + #[allow(clippy::cast_ptr_alignment)] + let instruction_data_len = *(input.add(offset) as *const u64) as usize; + offset += size_of::(); + + let instruction_data = { from_raw_parts(input.add(offset), instruction_data_len) }; + offset += instruction_data_len; + + // Program Id + let program_id: &Pubkey = &*(input.add(offset) as *const Pubkey); + + Some((program_id, accounts, instruction_data)) +} + +#[derive(Clone, PartialEq, Eq)] +#[repr(C)] +pub struct NoStdAccountInfo4 { + inner: *mut NoStdAccountInfo4Inner, +} + +impl NoStdAccountInfo4 { + /// SAFETY: you must ensure that this pointer IS + REMAINS valid. + pub unsafe fn from(inner: *mut NoStdAccountInfo4Inner) -> NoStdAccountInfo4 { + NoStdAccountInfo4 { inner } + } +} + +#[derive(Clone, Pod, Zeroable, Copy, Default)] +#[repr(C)] +pub struct NoStdAccountInfo4Inner { + /// 0) We reuse the duplicate flag for this. We set it to 0b0000_0000. + /// 1) We use the first four bits to track state of lamport borrow + /// 2) We use the second four bits to track state of data borrow + /// + /// 4 bit state: [1 bit mutable borrow flag | u3 immmutable borrow flag] + /// This gives us up to 7 immutable borrows. Note that does not mean 7 + /// duplicate account infos, but rather 7 calls to borrow lamports or + /// borrow data across all duplicate account infos. + borrow_state: u8, + + /// Was the transaction signed by this account's public key? + is_signer: u8, + + /// Is the account writable? + is_writable: u8, + + /// This account's data contains a loaded program (and is now read-only) + executable: u8, + + padding: u32, + + /// Public key of the account + key: Pubkey, + /// Program that owns this account + owner: Pubkey, + + /// The lamports in the account. Modifiable by programs. + lamports: u64, + data_len: usize, +} + +#[repr(C)] +pub struct AccountMetaC { + pubkey: *const Pubkey, + is_writable: bool, + is_signer: bool, +} + +pub struct AccountInfoC { + pub key: *const Pubkey, /* Public key of the account */ + pub lamports: *const u64, /* Number of lamports owned by this account */ + pub data_len: u64, /* Length of data in bytes */ + pub data: *const u8, /* On-chain data within this account */ + pub owner: *const Pubkey, /* Program that owns this account */ + pub rent_epoch: u64, /* The epoch at which this account will next owe rent */ + pub is_signer: bool, /* Transaction was signed by this account's key? */ + pub is_writable: bool, /* Is the account writable? */ + pub executable: bool, /* This account's data contains a loaded program (and is now read-only) */ +} + +impl AccountInfoC { + #[inline(always)] + pub fn to_meta_c(&self) -> AccountMetaC { + AccountMetaC { + pubkey: self.key, + is_writable: self.is_writable, + is_signer: self.is_signer, + } + } + #[inline(always)] + pub fn to_meta_c_signer(&self) -> AccountMetaC { + AccountMetaC { + pubkey: self.key, + is_writable: self.is_writable, + is_signer: true, + } + } +} + +#[derive(Debug, PartialEq)] +#[repr(C)] +pub struct InstructionC { + pub program_id: *const Pubkey, + pub accounts: *const AccountMetaC, + pub accounts_len: u64, + pub data: *const u8, + pub data_len: u64, +} + +pub struct Ref<'a, T: ?Sized> { + value: &'a T, + state: NonNull, + is_lamport: bool, +} + +impl<'a, T: ?Sized> core::ops::Deref for Ref<'a, T> { + type Target = T; + fn deref(&self) -> &Self::Target { + self.value + } +} + +impl<'a, T: ?Sized> Drop for Ref<'a, T> { + // We just need to decrement the immutable borrow count + fn drop(&mut self) { + if self.is_lamport { + unsafe { *self.state.as_mut() -= 1 << 4 }; + } else { + unsafe { *self.state.as_mut() -= 1 }; + } + } +} + +pub struct RefMut<'a, T: ?Sized> { + value: &'a mut T, + state: NonNull, + is_lamport: bool, +} + +impl<'a, T: ?Sized> core::ops::Deref for RefMut<'a, T> { + type Target = T; + fn deref(&self) -> &Self::Target { + self.value + } +} +impl<'a, T: ?Sized> core::ops::DerefMut for RefMut<'a, T> { + fn deref_mut(&mut self) -> &mut ::Target { + self.value + } +} + +impl<'a, T: ?Sized> Drop for RefMut<'a, T> { + // We need to unset the mut borrow flag + fn drop(&mut self) { + if self.is_lamport { + unsafe { *self.state.as_mut() &= 0b_0111_1111 }; + } else { + unsafe { *self.state.as_mut() &= 0b_1111_0111 }; + } + } +} + +/// SAFETY: +/// Within the standard library, RcBox uses repr(C) which guarantees +/// we will always have the layout +/// +/// strong: isize, +/// weak: isize, +/// value: T +/// +/// For us, T -> RefCell. Since RefCell has T: ?Sized, this +/// guarantees that the inner fields of RefCell are not reordered. +/// So, in conclusion, this type has a stable memory layout. +#[repr(C, align(8))] +pub struct RcRefCellInner<'a, T> { + strong: isize, + weak: isize, + refcell: RefCell, + phantom_data: PhantomData<&'a mut ()>, +} + +impl<'a, T> RcRefCellInner<'a, T> { + pub fn new(value: T) -> RcRefCellInner<'a, T> { + RcRefCellInner { + strong: 2, + weak: 2, + refcell: RefCell::new(value), + phantom_data: PhantomData, + } + } + + /// NOTE: when the last Rc is dropped, the strong count will reach + /// one. So, it will not deallocate, which is fine because the + /// Rc points to stack memory. + /// + /// SAFETY: [RcRefCellInner] must NOT be dropped before this Rc is + /// used. There can be no safe abstraction that guarantees users + /// do this because we cannot make Rc inherit the borrowed + /// lifetime. + unsafe fn as_rcrc(&self) -> Rc> { + // Rc::from_raw expects pointer to T + unsafe { Rc::from_raw(&self.refcell as *const RefCell) } + } +} + +#[inline(always)] +const fn offset(ptr: *const T, offset: usize) -> *const U { + unsafe { (ptr as *const u8).add(offset) as *const U } +} + +impl NoStdAccountInfo4 { + pub fn to_info_c(&self) -> AccountInfoC { + AccountInfoC { + key: offset(self.inner, 8), + lamports: offset(self.inner, 72), + data_len: self.data_len() as u64, + data: offset(self.inner, 88), + owner: offset(self.inner, 40), + rent_epoch: 0, + is_signer: self.is_signer(), + is_writable: self.is_writable(), + executable: self.executable(), + } + } + pub fn to_meta_c(&self) -> AccountMetaC { + AccountMetaC { + pubkey: offset(self.inner, 8), + is_writable: self.is_writable(), + is_signer: self.is_signer(), + } + } + + pub unsafe fn unchecked_info_prep<'a>( + &'a self, + ) -> (RcRefCellInner<&'a mut u64>, RcRefCellInner<&'a mut [u8]>) { + let lamports_inner = RcRefCellInner::new(self.unchecked_borrow_mut_lamports()); + let data_inner = RcRefCellInner::new(self.unchecked_borrow_mut_data()); + (lamports_inner, data_inner) + } + + pub unsafe fn info_with<'a>( + &'a self, + lamports_data: &'a (RcRefCellInner<&'a mut u64>, RcRefCellInner<&'a mut [u8]>), + ) -> AccountInfo<'a> { + let (lamports, data) = lamports_data; + AccountInfo { + key: self.key(), + lamports: unsafe { lamports.as_rcrc() }, + data: unsafe { data.as_rcrc() }, + owner: self.owner(), + rent_epoch: u64::MAX, + is_signer: self.is_signer(), + is_writable: self.is_writable(), + executable: self.executable(), + } + } + + #[inline(always)] + pub fn key(&self) -> &Pubkey { + unsafe { &(*self.inner).key } + } + #[inline(always)] + pub fn owner(&self) -> &Pubkey { + unsafe { &(*self.inner).owner } + } + #[inline(always)] + pub fn is_signer(&self) -> bool { + unsafe { (*self.inner).is_signer != 0 } + } + #[inline(always)] + pub fn is_writable(&self) -> bool { + unsafe { (*self.inner).is_writable != 0 } + } + #[inline(always)] + pub fn executable(&self) -> bool { + unsafe { (*self.inner).executable != 0 } + } + #[inline(always)] + pub fn data_len(&self) -> usize { + unsafe { (*self.inner).data_len } + } + + pub unsafe fn unchecked_borrow_lamports(&self) -> &u64 { + &(*self.inner).lamports + } + pub unsafe fn unchecked_borrow_mut_lamports(&self) -> &mut u64 { + &mut (*self.inner).lamports + } + pub unsafe fn unchecked_borrow_data(&self) -> &[u8] { + core::slice::from_raw_parts(self.data_ptr(), (*self.inner).data_len) + } + pub unsafe fn unchecked_borrow_mut_data(&self) -> &mut [u8] { + core::slice::from_raw_parts_mut(self.data_ptr(), (*self.inner).data_len) + } + + pub fn try_borrow_lamports(&self) -> Option> { + let borrow_state = unsafe { &mut (*self.inner).borrow_state }; + + // Check if mutable borrow is already taken + if *borrow_state & 0b_1000_0000 != 0 { + return None; + } + + // Check if we have reached the max immutable borrow count + if *borrow_state & 0b_0111_0000 == 0b_0111_0000 { + return None; + } + + // Increment the immutable borrow count + *borrow_state += 1 << 4; + + // Return the reference to lamports + Some(Ref { + value: unsafe { &(*self.inner).lamports }, + state: unsafe { NonNull::new_unchecked(&mut (*self.inner).borrow_state) }, + is_lamport: true, + }) + } + + pub fn try_borrow_mut_lamports(&self) -> Option> { + let borrow_state = unsafe { &mut (*self.inner).borrow_state }; + + // Check if any borrow (mutable or immutable) is already taken for lamports + if *borrow_state & 0b_1111_0000 != 0 { + return None; + } + + // Set the mutable lamport borrow flag + *borrow_state |= 0b_1000_0000; + + // Return the mutable reference to lamports + Some(RefMut { + value: unsafe { &mut (*self.inner).lamports }, + state: unsafe { NonNull::new_unchecked(&mut (*self.inner).borrow_state) }, + is_lamport: true, + }) + } + + pub fn try_borrow_data(&self) -> Option> { + let borrow_state = unsafe { &mut (*self.inner).borrow_state }; + + // Check if mutable data borrow is already taken (most significant bit of the data_borrow_state) + if *borrow_state & 0b_0000_1000 != 0 { + return None; + } + + // Check if we have reached the max immutable data borrow count (7) + if *borrow_state & 0b0111 == 0b0111 { + return None; + } + + // Increment the immutable data borrow count + *borrow_state += 1; + + // Return the reference to data + Some(Ref { + value: unsafe { core::slice::from_raw_parts(self.data_ptr(), (*self.inner).data_len) }, + state: unsafe { NonNull::new_unchecked(&mut (*self.inner).borrow_state) }, + is_lamport: false, + }) + } + + pub fn try_borrow_mut_data(&self) -> Option> { + let borrow_state = unsafe { &mut (*self.inner).borrow_state }; + + // Check if any borrow (mutable or immutable) is already taken for data + if *borrow_state & 0b_0000_1111 != 0 { + return None; + } + + // Set the mutable data borrow flag + *borrow_state |= 0b0000_1000; + + assert_eq!(self.data_ptr() as usize % 8, 0); // TODO REMOVE + + // Return the mutable reference to data + Some(RefMut { + value: unsafe { + core::slice::from_raw_parts_mut(self.data_ptr(), (*self.inner).data_len) + }, + state: unsafe { NonNull::new_unchecked(&mut (*self.inner).borrow_state) }, + is_lamport: false, + }) + } + + // private + fn data_ptr(&self) -> *mut u8 { + unsafe { (self.inner as *const _ as *mut u8).add(size_of::()) } + } +} diff --git a/src/instruction.rs b/src/instruction.rs index 16a04a5..94d7f63 100644 --- a/src/instruction.rs +++ b/src/instruction.rs @@ -3,7 +3,7 @@ use num_enum::TryFromPrimitive; use shank::ShankInstruction; use solana_program::pubkey::Pubkey; -use crate::{impl_to_bytes, state::Hash}; +use crate::{impl_instruction_from_bytes, impl_to_bytes, state::Hash}; #[repr(u8)] #[derive(Clone, Copy, Debug, Eq, PartialEq, ShankInstruction, TryFromPrimitive)] @@ -135,3 +135,10 @@ impl_to_bytes!(MineArgs); impl_to_bytes!(ClaimArgs); impl_to_bytes!(UpdateAdminArgs); impl_to_bytes!(UpdateDifficultyArgs); + +impl_instruction_from_bytes!(InitializeArgs); +impl_instruction_from_bytes!(CreateProofArgs); +impl_instruction_from_bytes!(MineArgs); +impl_instruction_from_bytes!(ClaimArgs); +impl_instruction_from_bytes!(UpdateAdminArgs); +impl_instruction_from_bytes!(UpdateDifficultyArgs); diff --git a/src/lib.rs b/src/lib.rs index 7bf8d70..37e578f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +mod entrypoint_nostd; pub mod error; pub mod instruction; mod loaders; @@ -20,6 +21,15 @@ declare_id!("CeJShZEAzBLwtcLQvbZc7UT38e4nUTn63Za5UFyYYDTS"); #[cfg(not(feature = "no-entrypoint"))] solana_program::entrypoint!(process_instruction); +// #[no_mangle] +// pub unsafe extern "C" fn entrypoint(input: *mut u8) -> u64 { +// let (program_id, accounts, instruction_data) = +// unsafe { crate::entrypoint_nostd::deserialize_nostd::<32>(input) }; +// match process_instruction(&program_id, &accounts, &instruction_data) { +// Ok(()) => solana_program::entrypoint::SUCCESS, +// Err(error) => error.into(), +// } +// } // TODO Set this before deployment /// The unix timestamp after which mining is allowed. diff --git a/src/loaders.rs b/src/loaders.rs index 726d4b2..233c669 100644 --- a/src/loaders.rs +++ b/src/loaders.rs @@ -13,6 +13,7 @@ pub fn load_signer<'a, 'info>(info: &'a AccountInfo<'info>) -> Result<(), Progra if !info.is_signer { return Err(ProgramError::MissingRequiredSignature); } + Ok(()) } @@ -23,12 +24,13 @@ pub fn load_bus<'a, 'info>( if info.owner.ne(&crate::id()) { return Err(ProgramError::InvalidAccountOwner); } + if info.data_is_empty() { return Err(ProgramError::UninitializedAccount); } let bus_data = info.data.borrow(); - let bus = bytemuck::try_from_bytes::(&bus_data).unwrap(); + let bus = Bus::try_from_bytes(&bus_data)?; if !(0..BUS_COUNT).contains(&(bus.id as usize)) { return Err(ProgramError::InvalidAccountData); @@ -49,12 +51,13 @@ pub fn load_proof<'a, 'info>( if info.owner.ne(&crate::id()) { return Err(ProgramError::InvalidAccountOwner); } + if info.data_is_empty() { return Err(ProgramError::UninitializedAccount); } let proof_data = info.data.borrow(); - let proof = bytemuck::try_from_bytes::(&proof_data).unwrap(); + let proof = Proof::try_from_bytes(&proof_data)?; if proof.authority.ne(&authority) { return Err(ProgramError::InvalidAccountData); @@ -74,6 +77,7 @@ pub fn load_treasury<'a, 'info>( if info.owner.ne(&crate::id()) { return Err(ProgramError::InvalidAccountOwner); } + if info.data_is_empty() { return Err(ProgramError::UninitializedAccount); } @@ -96,6 +100,7 @@ pub fn load_mint<'a, 'info>( if info.owner.ne(&spl_token::id()) { return Err(ProgramError::InvalidAccountOwner); } + if info.data_is_empty() { return Err(ProgramError::UninitializedAccount); } @@ -125,6 +130,7 @@ pub fn load_token_account<'a, 'info>( if info.owner.ne(&spl_token::id()) { return Err(ProgramError::InvalidAccountOwner); } + if info.data_is_empty() { return Err(ProgramError::UninitializedAccount); } @@ -136,6 +142,7 @@ pub fn load_token_account<'a, 'info>( if account.mint.ne(&mint) { return Err(ProgramError::InvalidAccountData); } + if let Some(owner) = owner { if account.owner.ne(owner) { return Err(ProgramError::InvalidAccountData); @@ -166,9 +173,11 @@ pub fn load_uninitialized_account<'a, 'info>( if info.owner.ne(&system_program::id()) { return Err(ProgramError::AccountAlreadyInitialized); } + if !info.data_is_empty() { return Err(ProgramError::AccountAlreadyInitialized); } + if !info.is_writable { return Err(ProgramError::InvalidAccountData); } @@ -190,9 +199,11 @@ pub fn load_account<'a, 'info>( if info.key.ne(&key) { return Err(ProgramError::InvalidAccountData); } + if is_writable && !info.is_writable { return Err(ProgramError::InvalidAccountData); } + Ok(()) } @@ -203,8 +214,10 @@ pub fn load_program<'a, 'info>( if info.key.ne(&key) { return Err(ProgramError::InvalidAccountData); } + if !info.executable { return Err(ProgramError::InvalidAccountData); } + Ok(()) } diff --git a/src/processor/claim.rs b/src/processor/claim.rs index caecef2..b92abb9 100644 --- a/src/processor/claim.rs +++ b/src/processor/claim.rs @@ -17,10 +17,9 @@ pub fn process_claim<'a, 'info>( data: &[u8], ) -> ProgramResult { // Parse args - let args = bytemuck::try_from_bytes::(data) - .or(Err(ProgramError::InvalidInstructionData))?; + let args = ClaimArgs::try_from_bytes(data)?; - // Validate accounts + // Load accounts let [signer, beneficiary_info, mint_info, proof_info, treasury_info, treasury_tokens_info, token_program] = accounts else { return Err(ProgramError::NotEnoughAccountKeys); }; @@ -38,7 +37,7 @@ pub fn process_claim<'a, 'info>( // Validate claim amout let mut proof_data = proof_info.data.borrow_mut(); - let mut proof = bytemuck::try_from_bytes_mut::(&mut proof_data).unwrap(); + let mut proof = Proof::try_from_bytes_mut(&mut proof_data)?; if proof.claimable_rewards.lt(&args.amount) { return Err(OreError::InvalidClaimAmount.into()); } @@ -48,7 +47,7 @@ pub fn process_claim<'a, 'info>( // Update lifetime status let mut treasury_data = treasury_info.data.borrow_mut(); - let mut treasury = bytemuck::try_from_bytes_mut::(&mut treasury_data).unwrap(); + let mut treasury = Treasury::try_from_bytes_mut(&mut treasury_data)?; treasury.total_claimed_rewards = treasury.total_claimed_rewards.saturating_add(args.amount); // Distribute tokens from treasury to beneficiary diff --git a/src/processor/create_proof.rs b/src/processor/create_proof.rs index 9ae5639..62b39aa 100644 --- a/src/processor/create_proof.rs +++ b/src/processor/create_proof.rs @@ -13,10 +13,9 @@ pub fn process_create_proof<'a, 'info>( data: &[u8], ) -> ProgramResult { // Parse args - let args = bytemuck::try_from_bytes::(data) - .or(Err(ProgramError::InvalidInstructionData))?; + let args = CreateProofArgs::try_from_bytes(data)?; - // Validate accounts + // Load accounts let [signer, proof_info, system_program] = accounts else { return Err(ProgramError::NotEnoughAccountKeys); }; @@ -34,7 +33,7 @@ pub fn process_create_proof<'a, 'info>( signer, )?; let mut proof_data = proof_info.data.borrow_mut(); - let mut proof = bytemuck::try_from_bytes_mut::(&mut proof_data).unwrap(); + let mut proof = Proof::try_from_bytes_mut(&mut proof_data)?; proof.bump = args.bump as u64; proof.authority = *signer.key; proof.claimable_rewards = 0; diff --git a/src/processor/initialize.rs b/src/processor/initialize.rs index 5165639..a3132b9 100644 --- a/src/processor/initialize.rs +++ b/src/processor/initialize.rs @@ -22,10 +22,9 @@ pub fn process_initialize<'a, 'info>( data: &[u8], ) -> ProgramResult { // Parse args - let args = bytemuck::try_from_bytes::(data) - .or(Err(ProgramError::InvalidInstructionData))?; + let args = InitializeArgs::try_from_bytes(data)?; - // Validate accounts + // Load accounts let [signer, bus_0_info, bus_1_info, bus_2_info, bus_3_info, bus_4_info, bus_5_info, bus_6_info, bus_7_info, mint_info, treasury_info, treasury_tokens_info, system_program, token_program, associated_token_program, rent_sysvar] = accounts else { return Err(ProgramError::NotEnoughAccountKeys); }; @@ -96,7 +95,7 @@ pub fn process_initialize<'a, 'info>( signer, )?; let mut treasury_data = treasury_info.data.borrow_mut(); - let mut treasury = bytemuck::try_from_bytes_mut::(&mut treasury_data).unwrap(); + let mut treasury = Treasury::try_from_bytes_mut(&mut treasury_data)?; treasury.bump = args.treasury_bump as u64; treasury.admin = *signer.key; treasury.epoch_start_at = 0; diff --git a/src/processor/mine.rs b/src/processor/mine.rs index bbcf283..8606448 100644 --- a/src/processor/mine.rs +++ b/src/processor/mine.rs @@ -26,10 +26,9 @@ pub fn process_mine<'a, 'info>( data: &[u8], ) -> ProgramResult { // Parse args - let args = - bytemuck::try_from_bytes::(data).or(Err(ProgramError::InvalidInstructionData))?; + let args = MineArgs::try_from_bytes(data)?; - // Validate accounts + // Load accounts let [signer, bus_info, proof_info, treasury_info, slot_hashes_info] = accounts else { return Err(ProgramError::NotEnoughAccountKeys); }; @@ -40,9 +39,9 @@ pub fn process_mine<'a, 'info>( load_sysvar(slot_hashes_info, sysvar::slot_hashes::id())?; // Validate epoch is active - let clock = Clock::get().unwrap(); + let clock = Clock::get().or(Err(ProgramError::InvalidAccountData))?; let treasury_data = treasury_info.data.borrow(); - let treasury = bytemuck::try_from_bytes::(&treasury_data).unwrap(); + let treasury = Treasury::try_from_bytes(&treasury_data)?; let epoch_end_at = treasury.epoch_start_at.saturating_add(EPOCH_DURATION); if clock.unix_timestamp.ge(&epoch_end_at) { return Err(OreError::EpochExpired.into()); @@ -50,7 +49,7 @@ pub fn process_mine<'a, 'info>( // Validate provided hash let mut proof_data = proof_info.data.borrow_mut(); - let mut proof = bytemuck::try_from_bytes_mut::(&mut proof_data).unwrap(); + let mut proof = Proof::try_from_bytes_mut(&mut proof_data)?; validate_hash( proof.hash.into(), args.hash.into(), @@ -61,7 +60,7 @@ pub fn process_mine<'a, 'info>( // Update claimable rewards let mut bus_data = bus_info.data.borrow_mut(); - let mut bus = bytemuck::try_from_bytes_mut::(&mut bus_data).unwrap(); + let mut bus = Bus::try_from_bytes_mut(&mut bus_data)?; if bus.available_rewards.lt(&treasury.reward_rate) { return Err(OreError::InsufficientBusRewards.into()); } diff --git a/src/processor/reset.rs b/src/processor/reset.rs index c876c64..d7eb1fd 100644 --- a/src/processor/reset.rs +++ b/src/processor/reset.rs @@ -16,7 +16,7 @@ pub fn process_reset<'a, 'info>( accounts: &'a [AccountInfo<'info>], _data: &[u8], ) -> ProgramResult { - // Validate accounts + // Load accounts let [signer, bus_0_info, bus_1_info, bus_2_info, bus_3_info, bus_4_info, bus_5_info, bus_6_info, bus_7_info, mint_info, treasury_info, treasury_tokens_info, token_program] = accounts else { return Err(ProgramError::NotEnoughAccountKeys); }; @@ -44,9 +44,9 @@ pub fn process_reset<'a, 'info>( ]; // Validate epoch has ended - let clock = Clock::get().unwrap(); + let clock = Clock::get().or(Err(ProgramError::InvalidAccountData))?; let mut treasury_data = treasury_info.data.borrow_mut(); - let mut treasury = bytemuck::try_from_bytes_mut::(&mut treasury_data).unwrap(); + let mut treasury = Treasury::try_from_bytes_mut(&mut treasury_data)?; let epoch_end_at = treasury.epoch_start_at.saturating_add(EPOCH_DURATION); if clock.unix_timestamp.lt(&epoch_end_at) { return Err(OreError::EpochActive.into()); @@ -56,7 +56,7 @@ pub fn process_reset<'a, 'info>( let mut total_available_rewards = 0u64; for i in 0..BUS_COUNT { let mut bus_data = busses[i].data.borrow_mut(); - let mut bus = bytemuck::try_from_bytes_mut::(&mut bus_data).unwrap(); + let mut bus = Bus::try_from_bytes_mut(&mut bus_data)?; total_available_rewards = total_available_rewards.saturating_add(bus.available_rewards); bus.available_rewards = BUS_EPOCH_REWARDS; } diff --git a/src/processor/update_admin.rs b/src/processor/update_admin.rs index 8c7646b..3e4df75 100644 --- a/src/processor/update_admin.rs +++ b/src/processor/update_admin.rs @@ -11,10 +11,9 @@ pub fn process_update_admin<'a, 'info>( data: &[u8], ) -> ProgramResult { // Parse args - let args = bytemuck::try_from_bytes::(data) - .or(Err(ProgramError::InvalidInstructionData))?; + let args = UpdateAdminArgs::try_from_bytes(data)?; - // Validate accounts + // Load accounts let [signer, treasury_info] = accounts else { return Err(ProgramError::NotEnoughAccountKeys); }; @@ -23,7 +22,7 @@ pub fn process_update_admin<'a, 'info>( // Validate admin signer let mut treasury_data = treasury_info.data.borrow_mut(); - let mut treasury = bytemuck::try_from_bytes_mut::(&mut treasury_data).unwrap(); + let mut treasury = Treasury::try_from_bytes_mut(&mut treasury_data)?; if !treasury.admin.eq(&signer.key) { return Err(ProgramError::MissingRequiredSignature); } diff --git a/src/processor/update_difficulty.rs b/src/processor/update_difficulty.rs index 2765b6c..2968a3e 100644 --- a/src/processor/update_difficulty.rs +++ b/src/processor/update_difficulty.rs @@ -11,10 +11,9 @@ pub fn process_update_difficulty<'a, 'info>( data: &[u8], ) -> ProgramResult { // Parse args - let args = bytemuck::try_from_bytes::(data) - .or(Err(ProgramError::InvalidInstructionData))?; + let args = UpdateDifficultyArgs::try_from_bytes(data)?; - // Validate accounts + // Load accounts let [signer, treasury_info] = accounts else { return Err(ProgramError::NotEnoughAccountKeys); }; @@ -23,7 +22,7 @@ pub fn process_update_difficulty<'a, 'info>( // Validate admin signer let mut treasury_data = treasury_info.data.borrow_mut(); - let mut treasury = bytemuck::try_from_bytes_mut::(&mut treasury_data).unwrap(); + let mut treasury = Treasury::try_from_bytes_mut(&mut treasury_data)?; if !treasury.admin.eq(&signer.key) { return Err(ProgramError::MissingRequiredSignature); } diff --git a/src/state/bus.rs b/src/state/bus.rs index 9f75d3d..8ebc1ca 100644 --- a/src/state/bus.rs +++ b/src/state/bus.rs @@ -1,6 +1,6 @@ use bytemuck::{Pod, Zeroable}; -use crate::impl_to_bytes; +use crate::{impl_account_from_bytes, impl_to_bytes}; #[repr(C)] #[derive(Clone, Copy, Debug, PartialEq, Pod, Zeroable)] @@ -16,3 +16,4 @@ pub struct Bus { } impl_to_bytes!(Bus); +impl_account_from_bytes!(Bus); diff --git a/src/state/hash.rs b/src/state/hash.rs index 9a55030..6e0b457 100644 --- a/src/state/hash.rs +++ b/src/state/hash.rs @@ -3,7 +3,7 @@ use std::mem::transmute; use bytemuck::{Pod, Zeroable}; use solana_program::keccak::{Hash as KeccakHash, HASH_BYTES}; -use crate::impl_to_bytes; +use crate::{impl_account_from_bytes, impl_to_bytes}; #[repr(C)] #[derive(Clone, Copy, Debug, PartialEq, Pod, Zeroable)] @@ -24,3 +24,4 @@ impl From for KeccakHash { } impl_to_bytes!(Hash); +impl_account_from_bytes!(Hash); diff --git a/src/state/proof.rs b/src/state/proof.rs index 7707b9d..c949158 100644 --- a/src/state/proof.rs +++ b/src/state/proof.rs @@ -1,7 +1,7 @@ use bytemuck::{Pod, Zeroable}; use solana_program::pubkey::Pubkey; -use crate::impl_to_bytes; +use crate::{impl_account_from_bytes, impl_to_bytes}; use super::Hash; @@ -28,3 +28,4 @@ pub struct Proof { } impl_to_bytes!(Proof); +impl_account_from_bytes!(Proof); diff --git a/src/state/treasury.rs b/src/state/treasury.rs index 1e5faa5..5e2f8bf 100644 --- a/src/state/treasury.rs +++ b/src/state/treasury.rs @@ -1,7 +1,7 @@ use bytemuck::{Pod, Zeroable}; use solana_program::pubkey::Pubkey; -use crate::impl_to_bytes; +use crate::{impl_account_from_bytes, impl_to_bytes}; use super::Hash; @@ -28,3 +28,4 @@ pub struct Treasury { } impl_to_bytes!(Treasury); +impl_account_from_bytes!(Treasury); diff --git a/src/utils.rs b/src/utils.rs index 3a9146e..a982106 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -42,3 +42,40 @@ macro_rules! impl_to_bytes { } }; } + +#[macro_export] +macro_rules! impl_account_from_bytes { + ($struct_name:ident) => { + impl $struct_name { + pub fn try_from_bytes( + data: &[u8], + ) -> Result<&Self, solana_program::program_error::ProgramError> { + bytemuck::try_from_bytes::(data).or(Err( + solana_program::program_error::ProgramError::InvalidAccountData, + )) + } + pub fn try_from_bytes_mut( + data: &mut [u8], + ) -> Result<&mut Self, solana_program::program_error::ProgramError> { + bytemuck::try_from_bytes_mut::(data).or(Err( + solana_program::program_error::ProgramError::InvalidAccountData, + )) + } + } + }; +} + +#[macro_export] +macro_rules! impl_instruction_from_bytes { + ($struct_name:ident) => { + impl $struct_name { + pub fn try_from_bytes( + data: &[u8], + ) -> Result<&Self, solana_program::program_error::ProgramError> { + bytemuck::try_from_bytes::(data).or(Err( + solana_program::program_error::ProgramError::InvalidInstructionData, + )) + } + } + }; +} diff --git a/tests/test_initialize.rs b/tests/test_initialize.rs index d72dea8..c4ad471 100644 --- a/tests/test_initialize.rs +++ b/tests/test_initialize.rs @@ -89,7 +89,7 @@ async fn test_initialize() { for i in 0..BUS_COUNT { let bus_account = banks.get_account(bus_pdas[i].0).await.unwrap().unwrap(); assert_eq!(bus_account.owner, ore::id()); - let bus = bytemuck::try_from_bytes::(&bus_account.data).unwrap(); + let bus = Bus::try_from_bytes(&bus_account.data).unwrap(); assert_eq!(bus.bump as u8, bus_pdas[i].1); assert_eq!(bus.id as u8, i as u8); assert_eq!(bus.available_rewards, 0); @@ -98,7 +98,7 @@ async fn test_initialize() { // Test treasury state let treasury_account = banks.get_account(treasury_pda.0).await.unwrap().unwrap(); assert_eq!(treasury_account.owner, ore::id()); - let treasury = bytemuck::try_from_bytes::(&treasury_account.data).unwrap(); + let treasury = Treasury::try_from_bytes(&treasury_account.data).unwrap(); assert_eq!(treasury.bump as u8, treasury_pda.1); assert_eq!(treasury.admin, payer.pubkey()); assert_eq!(treasury.difficulty, INITIAL_DIFFICULTY.into()); diff --git a/tests/test_mine.rs b/tests/test_mine.rs index afacb91..f88705c 100644 --- a/tests/test_mine.rs +++ b/tests/test_mine.rs @@ -48,12 +48,12 @@ async fn test_mine() { // Assert proof state let proof_account = banks.get_account(proof_pda.0).await.unwrap().unwrap(); assert_eq!(proof_account.owner, ore::id()); - let proof = bytemuck::try_from_bytes::(&proof_account.data).unwrap(); + let proof = Proof::try_from_bytes(&proof_account.data).unwrap(); // Assert proof state let treasury_pda = Pubkey::find_program_address(&[TREASURY], &ore::id()); let treasury_account = banks.get_account(treasury_pda.0).await.unwrap().unwrap(); - let treasury = bytemuck::try_from_bytes::(&treasury_account.data).unwrap(); + let treasury = Treasury::try_from_bytes(&treasury_account.data).unwrap(); // Find next hash let (next_hash, nonce) = find_next_hash( @@ -72,6 +72,9 @@ async fn test_mine() { AccountMeta::new(bus_pda.0, false), AccountMeta::new(proof_pda.0, false), AccountMeta::new_readonly(treasury_pda.0, false), + // AccountMeta::new(treasury_pda.0, false), + // AccountMeta::new(proof_pda.0, false), + // AccountMeta::new(bus_pda.0, false), AccountMeta::new_readonly(sysvar::slot_hashes::id(), false), ], data: [ diff --git a/tests/test_reset.rs b/tests/test_reset.rs index 1a71f7f..3d73df8 100644 --- a/tests/test_reset.rs +++ b/tests/test_reset.rs @@ -74,7 +74,7 @@ async fn test_reset() { for i in 0..BUS_COUNT { let bus_account = banks.get_account(bus_pdas[i].0).await.unwrap().unwrap(); assert_eq!(bus_account.owner, ore::id()); - let bus = bytemuck::try_from_bytes::(&bus_account.data).unwrap(); + let bus = Bus::try_from_bytes(&bus_account.data).unwrap(); println!( "Bus {:?} {:?} {:?}", bus_pdas[i].0, @@ -89,7 +89,7 @@ async fn test_reset() { // Test treasury state let treasury_account = banks.get_account(treasury_pda.0).await.unwrap().unwrap(); assert_eq!(treasury_account.owner, ore::id()); - let treasury = bytemuck::try_from_bytes::(&treasury_account.data).unwrap(); + let treasury = Treasury::try_from_bytes(&treasury_account.data).unwrap(); assert_eq!(treasury.bump as u8, treasury_pda.1); assert_eq!( treasury.admin,