|
| 1 | +module |
| 2 | + |
| 3 | +public import Lean |
| 4 | +import Binary.Basic |
| 5 | +import Binary.Get |
| 6 | +public import Binary.Hex |
| 7 | + |
| 8 | +public meta section |
| 9 | + |
| 10 | +namespace Binary |
| 11 | + |
| 12 | +open Lean Meta Elab Parser Term |
| 13 | + |
| 14 | +declare_syntax_cat get_proc (behavior := symbol) |
| 15 | + |
| 16 | +syntax get_proc_ascription_types := |
| 17 | + &"UInt8" <|> |
| 18 | + &"UInt16" <|> |
| 19 | + &"UInt32" <|> |
| 20 | + &"UInt64" <|> |
| 21 | + &"Int8" <|> |
| 22 | + &"Int16" <|> |
| 23 | + &"Int32" <|> |
| 24 | + &"Int64" <|> |
| 25 | + &"Float32" <|> |
| 26 | + &"Float" |
| 27 | + |
| 28 | +syntax get_proc_ascription_bytes := &"bytes" term |
| 29 | +syntax get_proc_ascription := " : " (get_proc_ascription_bytes <|> (get_proc_ascription_types (" < " <|> " > ")?) <|> term) |
| 30 | + |
| 31 | +syntax Parser.ident get_proc_ascription : get_proc |
| 32 | +syntax Parser.ident " ← " term : get_proc |
| 33 | +syntax &"hex " Hex.hexStr : get_proc |
| 34 | +syntax &"yield " term : get_proc |
| 35 | +syntax num get_proc_ascription : get_proc |
| 36 | + |
| 37 | +syntax (name := getProcStx) "get!" "{" get_proc,*,? "}" : term |
| 38 | + |
| 39 | +private def getAscriptionNumeralType (le? : Option Bool) : TSyntax ``get_proc_ascription_types → TermElabM (TSyntax `term) := fun stx => do |
| 40 | + match stx with |
| 41 | + | `(get_proc_ascription_types| UInt8) => ``(getThe UInt8) |
| 42 | + | `(get_proc_ascription_types| Int8) => ``(getThe Int8) |
| 43 | + | `(get_proc_ascription_types| $x) => |
| 44 | + let ty := x.raw[0][0].getAtomVal |
| 45 | + let t := mkIdentFrom x (Name.mkStr1 ty) |
| 46 | + let l := mkIdentFrom x (Name.str `Binary.Primitive.LE s!"instDecode{ty}") |
| 47 | + let b := mkIdentFrom x (Name.str `Binary.Primitive.BE s!"instDecode{ty}") |
| 48 | + match le? with |
| 49 | + | .none => ``(getThe $t) |
| 50 | + | .some true => ``(@get _ $l) |
| 51 | + | .some false => ``(@get _ $b) |
| 52 | + |
| 53 | +private def getAscription : TSyntax ``get_proc_ascription → TermElabM (TSyntax `term) := fun stx => do |
| 54 | + match stx with |
| 55 | + | `(get_proc_ascription| : $bs:get_proc_ascription_bytes) => |
| 56 | + match bs with |
| 57 | + | `(get_proc_ascription_bytes| bytes $len) => |
| 58 | + ``(get_bytes $len) |
| 59 | + | _ => throwUnsupportedSyntax |
| 60 | + | `(get_proc_ascription| : $type:get_proc_ascription_types $[$tk?]?) => |
| 61 | + let le? := tk?.map fun x => x.raw.getAtomVal.trim == " < " |
| 62 | + getAscriptionNumeralType le? type |
| 63 | + | `(get_proc_ascription| : $type:term) => |
| 64 | + let type' ← elabType type |
| 65 | + let instType := Expr.app (Expr.const ``Decode []) type' |
| 66 | + let .some _ ← synthInstance? instType | throwErrorAt type "failed to synthesize instance {instType}" |
| 67 | + ``(getThe $type) |
| 68 | + | _ => throwUnsupportedSyntax |
| 69 | + |
| 70 | +private def getFileLoc (pos : String.Pos.Raw) : MetaM String := do |
| 71 | + let map ← getFileMap |
| 72 | + let pos := map.toPosition pos |
| 73 | + let fileName ← getFileName |
| 74 | + return s!"{fileName}:{pos.line}:{pos.column}: " |
| 75 | + |
| 76 | +@[term_elab getProcStx] |
| 77 | +public def elabGetProcStx : TermElab := fun stx type? => do |
| 78 | + let `(getProcStx| get! { $body,* }) := stx | throwUnsupportedSyntax |
| 79 | + let es := body.getElems |
| 80 | + let mut ns := #[] |
| 81 | + let mut ts := #[] |
| 82 | + for e in es, i in List.range es.size do |
| 83 | + match e with |
| 84 | + | `(get_proc| $x:ident $ascr) => |
| 85 | + let a ← getAscription ascr |
| 86 | + let s ← `(doSeqItem| let $x ← $a:term) |
| 87 | + ns := ns.push x |
| 88 | + ts := ts.push s |
| 89 | + | `(get_proc| $x:ident ← $action) => |
| 90 | + let s ← `(doSeqItem| let $x ← $action:term) |
| 91 | + ns := ns.push x |
| 92 | + ts := ts.push s |
| 93 | + | `(get_proc| hex $hex:hexStr) => |
| 94 | + let vs ← Hex.elabHexStr hex.raw[0][0].getAtomVal |
| 95 | + if vs.isEmpty then |
| 96 | + continue |
| 97 | + let pos? ← liftM <| hex.raw.getPos?.mapM getFileLoc |
| 98 | + let posStr := quote <| pos?.getD "" |
| 99 | + let f := fun (v : UInt8) => `(doSeqItem| let $(quote v.toNat):num ← getThe UInt8 | throw (DecodeError.userError s!"{$posStr:str}hex literal assertion failed")) |
| 100 | + let rs ← vs.mapM f |
| 101 | + ts := ts.append rs |
| 102 | + | `(get_proc| $n:num $ascr) => |
| 103 | + let a ← getAscription ascr |
| 104 | + let pos? ← liftM <| n.raw.getPos?.mapM getFileLoc |
| 105 | + let posStr := quote <| pos?.getD "" |
| 106 | + let s ← `(doSeqItem| let $n:num ← $a:term | throw (DecodeError.userError s!"{$posStr:str}numeral literal assertion failed")) |
| 107 | + ts := ts.push s |
| 108 | + | `(get_proc| yield $r) => |
| 109 | + if i != es.size - 1 then |
| 110 | + throwErrorAt e "yield must be the last element" |
| 111 | + let s ← `(doSeqItem| return $r) |
| 112 | + ts := ts.push s |
| 113 | + | _ => throwUnsupportedSyntax |
| 114 | + let s ← `(do |
| 115 | + $ts* |
| 116 | + ) |
| 117 | + withMacroExpansion stx s do |
| 118 | + elabTerm s type? |
| 119 | + |
| 120 | +end Binary |
| 121 | + |
| 122 | +end |
0 commit comments