@@ -4,13 +4,16 @@ import {
44 ExpressionStatement ,
55 ExternalReferenceType ,
66 FunctionCall ,
7+ Identifier ,
78 Literal ,
9+ LiteralKind ,
810 Return ,
911} from 'solc-typed-ast' ;
1012import { AST } from '../../ast/ast' ;
11- import { CairoAssert } from '../../ast/cairoNodes' ;
1213import { ASTMapper } from '../../ast/mapper' ;
14+ import { cloneASTNode } from '../../utils/cloning' ;
1315import { createBoolLiteral } from '../../utils/nodeTemplates' ;
16+ import { toHexString } from '../../utils/utils' ;
1417
1518export class Require extends ASTMapper {
1619 // Function to add passes that should have been run before this pass
@@ -28,7 +31,7 @@ export class Require extends ASTMapper {
2831 return ;
2932 }
3033
31- // Since the cairoAssert is not null, we have a require/revert/assert function call at hand
34+ // Since cairoAssert is not null, this is a require/revert/assert function call
3235 assert ( expressionNode instanceof FunctionCall ) ;
3336
3437 ast . replaceNode ( node , cairoAssert ) ;
@@ -44,42 +47,49 @@ export class Require extends ASTMapper {
4447 }
4548
4649 requireToCairoAssert ( expression : Expression | undefined , ast : AST ) : ExpressionStatement | null {
47- if ( ! ( expression instanceof FunctionCall ) ) return null ;
48- if ( expression . vFunctionCallType !== ExternalReferenceType . Builtin ) {
49- return null ;
50- }
50+ if (
51+ expression instanceof FunctionCall &&
52+ expression . vFunctionCallType === ExternalReferenceType . Builtin &&
53+ [ 'assert' , 'require' , 'revert' ] . includes ( expression . vIdentifier )
54+ ) {
55+ // TODO: The identifier node generated by the solc-typed-ast has different typestrings
56+ // and referencedDeclaration number for assert, require and revert Solidity functions.
57+ // Check typestring when updating solc-typed-ast version.
58+ const assertIdentifier = cloneASTNode ( expression . vExpression , ast ) ;
59+ assert ( assertIdentifier instanceof Identifier ) ;
60+ assertIdentifier . name = 'assert' ;
5161
52- if ( expression . vIdentifier === 'require' || expression . vIdentifier === 'assert' ) {
53- const requireMessage =
54- expression . vArguments [ 1 ] instanceof Literal ? expression . vArguments [ 1 ] . value : null ;
62+ const args : Expression [ ] = [ ] ;
63+ if ( expression . vIdentifier === 'revert' ) args . push ( createBoolLiteral ( false , ast ) ) ;
64+ args . push ( ...expression . vArguments ) ;
65+ if ( args . length < 2 ) {
66+ const message = 'Assertion error' ;
67+ args . push (
68+ new Literal (
69+ ast . reserveId ( ) ,
70+ '' ,
71+ `literal_string "${ message } "` ,
72+ LiteralKind . String ,
73+ toHexString ( message ) ,
74+ message ,
75+ ) ,
76+ ) ;
77+ }
5578
5679 return new ExpressionStatement (
5780 ast . reserveId ( ) ,
5881 expression . src ,
59- new CairoAssert (
82+ new FunctionCall (
6083 ast . reserveId ( ) ,
61- expression . src ,
62- expression . vArguments [ 0 ] ,
63- requireMessage ,
64- expression . raw ,
65- ) ,
66- ) ;
67- } else if ( expression . vIdentifier === 'revert' ) {
68- const revertMessage =
69- expression . vArguments [ 0 ] instanceof Literal ? expression . vArguments [ 0 ] . value : null ;
70-
71- return new ExpressionStatement (
72- ast . reserveId ( ) ,
73- expression . src ,
74- new CairoAssert (
75- ast . reserveId ( ) ,
76- expression . src ,
77- createBoolLiteral ( false , ast ) ,
78- revertMessage ,
79- expression . raw ,
84+ '' ,
85+ expression . typeString ,
86+ expression . kind ,
87+ assertIdentifier ,
88+ args ,
8089 ) ,
8190 ) ;
8291 }
92+
8393 return null ;
8494 }
8595}
0 commit comments