@@ -5,9 +5,10 @@ use std::{
55} ;
66
77use crate :: {
8- error:: { Error , Result } ,
8+ error:: { CorruptedDataError , CorruptedDataKind , Error , Result } ,
99 Either ,
1010} ;
11+ use crc32fast:: Hasher ;
1112use layout:: { CellPointerFlags , CellPointerMetadata , PageHeader } ;
1213use serde:: { de:: DeserializeOwned , Serialize } ;
1314use spec:: { LocationOffset , CELL_POINTER_SIZE , PAGE_FREE_SPACE_BYTE , PAGE_HEADER_SIZE , PAGE_SIZE } ;
@@ -35,28 +36,45 @@ impl<T: Serialize + DeserializeOwned + PartialOrd + Ord + Clone> Page<T> {
3536 upper,
3637 lower,
3738 special,
39+ checksum : 0 ,
3840 } ;
3941
4042 let header_bytes = bincode:: serialize ( & header) . map_err ( Error :: SerializeError ) ?;
4143 io. seek ( SeekFrom :: Start ( 0 ) ) . map_err ( Error :: IoError ) ?;
4244 io. write ( & header_bytes) . map_err ( Error :: IoError ) ?;
4345
44- Ok ( Page {
46+ let mut page : Page < T > = Page {
4547 header,
4648 io,
4749 _t : PhantomData ,
48- } )
50+ } ;
51+
52+ // update checksum
53+ page. write_header ( ) ?;
54+
55+ Ok ( page)
4956 }
5057
5158 pub fn open ( data : [ u8 ; PAGE_SIZE as usize ] ) -> Result < Self > {
5259 let mut io = Cursor :: new ( data) ;
5360 let header = Self :: read_header ( & mut io) ?;
5461
55- Ok ( Self {
62+ let page = Self {
5663 io,
57- header,
64+ header : header . clone ( ) ,
5865 _t : PhantomData ,
59- } )
66+ } ;
67+
68+ let checksum = page. checksum ( ) ;
69+
70+ if header. checksum != checksum {
71+ return Err ( Error :: CorruptedData ( CorruptedDataError {
72+ kind : CorruptedDataKind :: ChecksumNotMatch ,
73+ message : "checksum does not match" . to_string ( ) ,
74+ } ) ) ;
75+ }
76+
77+ Ok ( page)
6078 }
6179
6280 pub fn write ( & mut self , data : T ) -> Result < ( LocationOffset , LocationOffset ) > {
@@ -442,6 +460,9 @@ impl<T: Serialize + DeserializeOwned + PartialOrd + Ord + Clone> Page<T> {
442460 }
443461
444462 fn write_header ( & mut self ) -> Result < ( ) > {
463+ let checksum = self . checksum ( ) ;
464+ self . header . checksum = checksum;
465+
445466 let buffer = bincode:: serialize ( & self . header ) . map_err ( Error :: SerializeError ) ?;
446467
447468 self . io . seek ( SeekFrom :: Start ( 0 ) ) . map_err ( Error :: IoError ) ?;
@@ -458,6 +479,12 @@ impl<T: Serialize + DeserializeOwned + PartialOrd + Ord + Clone> Page<T> {
458479
459480 bincode:: deserialize ( & buffer) . map_err ( Error :: SerializeError )
460481 }
482+
483+ fn checksum ( & self ) -> u32 {
484+ let mut hasher = Hasher :: new ( ) ;
485+ hasher. update ( & self . io . get_ref ( ) [ PAGE_HEADER_SIZE ..] ) ;
486+ hasher. finalize ( )
487+ }
461488}
462489
463490#[ cfg( test) ]
@@ -589,4 +616,17 @@ mod page_tests {
589616
590617 assert_eq ! ( value, 90 ) ;
591618 }
619+
620+ #[ test]
621+ fn page_checksum ( ) {
622+ let mut page = Page :: < u32 > :: create ( 0 ) . unwrap ( ) ;
623+
624+ page. write ( 99 ) . unwrap ( ) ;
625+
626+ let mut page_bytes = page. to_bytes ( ) . unwrap ( ) ;
627+ //change a random byte
628+ page_bytes[ 26 ] = 2u8 ;
629+
630+ Page :: < u32 > :: open ( page_bytes) . err ( ) . unwrap ( ) ;
631+ }
592632}
0 commit comments