diff --git a/lib/prism/translation/ripper.rb b/lib/prism/translation/ripper.rb index bbfa1f4d05d175..2f66bab97ee5b8 100644 --- a/lib/prism/translation/ripper.rb +++ b/lib/prism/translation/ripper.rb @@ -606,6 +606,9 @@ def parse # alias foo bar # ^^^^^^^^^^^^^ def visit_alias_method_node(node) + bounds(node.keyword_loc) + on_kw("alias") + new_name = visit(node.new_name) old_name = visit(node.old_name) @@ -616,6 +619,9 @@ def visit_alias_method_node(node) # alias $foo $bar # ^^^^^^^^^^^^^^^ def visit_alias_global_variable_node(node) + bounds(node.keyword_loc) + on_kw("alias") + new_name = visit_alias_global_variable_node_value(node.new_name) old_name = visit_alias_global_variable_node_value(node.old_name) @@ -661,6 +667,10 @@ def visit_alternation_pattern_node(node) # ^^^^^^^ def visit_and_node(node) left = visit(node.left) + if node.operator == "and" + bounds(node.operator_loc) + on_kw("and") + end right = visit(node.right) bounds(node.location) @@ -887,8 +897,18 @@ def visit_back_reference_read_node(node) # begin end # ^^^^^^^^^ def visit_begin_node(node) + if node.begin_keyword_loc + bounds(node.begin_keyword_loc) + on_kw("begin") + end + clauses = visit_begin_node_clauses(node.begin_keyword_loc, node, false) + if node.end_keyword_loc + bounds(node.end_keyword_loc) + on_kw("end") + end + bounds(node.location) on_begin(clauses) end @@ -909,6 +929,9 @@ def visit_begin_node(node) rescue_clause = visit(node.rescue_clause) else_clause = unless (else_clause_node = node.else_clause).nil? + bounds(else_clause_node.else_keyword_loc) + on_kw("else") + else_statements = if else_clause_node.statements.nil? 
[nil] @@ -966,6 +989,11 @@ def visit_block_node(node) braces = node.opening == "{" parameters = visit(node.parameters) + unless braces + bounds(node.opening_loc) + on_kw("do") + end + body = case node.body when nil @@ -987,6 +1015,11 @@ def visit_block_node(node) raise end + unless braces + bounds(node.closing_loc) + on_kw("end") + end + if braces bounds(node.location) on_brace_block(parameters, body) @@ -1037,6 +1070,9 @@ def visit_block_parameters_node(node) # break foo # ^^^^^^^^^ def visit_break_node(node) + bounds(node.keyword_loc) + on_kw("break") + if node.arguments.nil? bounds(node.location) on_break(on_args_new) @@ -1103,6 +1139,9 @@ def visit_call_node(node) on_unary(node.name, receiver) when :! if node.message == "not" + bounds(node.message_loc) + on_kw("not") + receiver = if !node.receiver.is_a?(ParenthesesNode) || !node.receiver.body.nil? visit(node.receiver) @@ -1347,10 +1386,21 @@ def visit_capture_pattern_node(node) # case foo; when bar; end # ^^^^^^^^^^^^^^^^^^^^^^^ def visit_case_node(node) + bounds(node.case_keyword_loc) + on_kw("case") + predicate = visit(node.predicate) + visited_conditions = node.conditions.map { |condition| visit(condition) } + visited_else_clause = visit(node.else_clause) + + if !node.else_clause + bounds(node.end_keyword_loc) + on_kw("end") + end + clauses = - node.conditions.reverse_each.inject(visit(node.else_clause)) do |current, condition| - on_when(*visit(condition), current) + visited_conditions.reverse_each.inject(visited_else_clause) do |current, condition| + on_when(*condition, current) end bounds(node.location) @@ -1360,10 +1410,23 @@ def visit_case_node(node) # case foo; in bar; end # ^^^^^^^^^^^^^^^^^^^^^ def visit_case_match_node(node) + bounds(node.case_keyword_loc) + on_kw("case") + predicate = visit(node.predicate) + visited_conditions = node.conditions.map do | condition| + visit(condition) + end + visited_else_clause = visit(node.else_clause) + + if !node.else_clause + bounds(node.end_keyword_loc) + 
on_kw("end") + end + clauses = - node.conditions.reverse_each.inject(visit(node.else_clause)) do |current, condition| - on_in(*visit(condition), current) + visited_conditions.reverse_each.inject(visited_else_clause) do |current, condition| + on_in(*condition, current) end bounds(node.location) @@ -1373,6 +1436,9 @@ def visit_case_match_node(node) # class Foo; end # ^^^^^^^^^^^^^^ def visit_class_node(node) + bounds(node.class_keyword_loc) + on_kw("class") + constant_path = if node.constant_path.is_a?(ConstantReadNode) bounds(node.constant_path.location) @@ -1384,6 +1450,9 @@ def visit_class_node(node) superclass = visit(node.superclass) bodystmt = visit_body_node(node.superclass&.location || node.constant_path.location, node.body, node.superclass.nil?) + bounds(node.end_keyword_loc) + on_kw("end") + bounds(node.location) on_class(constant_path, superclass, bodystmt) end @@ -1631,6 +1700,9 @@ def visit_constant_path_target_node(node) # def self.foo; end # ^^^^^^^^^^^^^^^^^ def visit_def_node(node) + bounds(node.def_keyword_loc) + on_kw("def") + receiver = visit(node.receiver) operator = if !node.operator_loc.nil? @@ -1664,6 +1736,11 @@ def visit_def_node(node) on_bodystmt(body, nil, nil, nil) end + if node.end_keyword_loc + bounds(node.end_keyword_loc) + on_kw("end") + end + bounds(node.location) if receiver on_defs(receiver, operator, name, parameters, bodystmt) @@ -1678,6 +1755,9 @@ def visit_def_node(node) # defined?(a) # ^^^^^^^^^^^ def visit_defined_node(node) + bounds(node.keyword_loc) + on_kw("defined?") + expression = visit(node.value) # Very weird circumstances here where something like: @@ -1700,6 +1780,9 @@ def visit_defined_node(node) # if foo then bar else baz end # ^^^^^^^^^^^^ def visit_else_node(node) + bounds(node.else_keyword_loc) + on_kw("else") + statements = if node.statements.nil? 
[nil] @@ -1709,8 +1792,12 @@ def visit_else_node(node) body end + else_statements = visit_statements_node_body(statements) + + bounds(node.end_keyword_loc) + on_kw("end") bounds(node.location) - on_else(visit_statements_node_body(statements)) + on_else(else_statements) end # "foo #{bar}" @@ -1748,6 +1835,9 @@ def visit_embedded_variable_node(node) # Visit an EnsureNode node. def visit_ensure_node(node) + bounds(node.ensure_keyword_loc) + on_kw("ensure") + statements = if node.statements.nil? [nil] @@ -1818,8 +1908,18 @@ def visit_float_node(node) # for foo in bar do end # ^^^^^^^^^^^^^^^^^^^^^ def visit_for_node(node) + bounds(node.for_keyword_loc) + on_kw("for") + index = visit(node.index) + bounds(node.in_keyword_loc) + on_kw("in") + collection = visit(node.collection) + if node.do_keyword_loc + bounds(node.do_keyword_loc) + on_kw("do") + end statements = if node.statements.nil? bounds(node.location) @@ -1828,6 +1928,9 @@ def visit_for_node(node) visit(node.statements) end + bounds(node.end_keyword_loc) + on_kw("end") + bounds(node.location) on_for(index, collection, statements) end @@ -1852,6 +1955,9 @@ def visit_forwarding_parameter_node(node) # super {} # ^^^^^^^^ def visit_forwarding_super_node(node) + bounds(node.keyword_loc) + on_kw("super") + if node.block.nil? bounds(node.location) on_zsuper @@ -2001,7 +2107,13 @@ def visit_if_node(node) bounds(node.location) on_ifop(predicate, truthy, falsy) elsif node.statements.nil? || (node.predicate.location.start_offset < node.statements.location.start_offset) + bounds(node.if_keyword_loc) + on_kw(node.if_keyword) predicate = visit(node.predicate) + if node.then_keyword_loc && node.then_keyword != "?" + bounds(node.then_keyword_loc) + on_kw("then") + end statements = if node.statements.nil? 
bounds(node.location) @@ -2011,6 +2123,11 @@ def visit_if_node(node) end subsequent = visit(node.subsequent) + if node.end_keyword_loc && !node.subsequent + bounds(node.end_keyword_loc) + on_kw("end") + end + bounds(node.location) if node.if_keyword == "if" on_if(predicate, statements, subsequent) @@ -2019,6 +2136,8 @@ def visit_if_node(node) end else statements = visit(node.statements.body.first) + bounds(node.if_keyword_loc) + on_kw(node.if_keyword) predicate = visit(node.predicate) bounds(node.location) @@ -2050,7 +2169,14 @@ def visit_in_node(node) # This is a special case where we're not going to call on_in directly # because we don't have access to the subsequent. Instead, we'll return # the component parts and let the parent node handle it. + bounds(node.in_loc) + on_kw("in") + pattern = visit_pattern_node(node.pattern) + if node.then_loc + bounds(node.then_loc) + on_kw("then") + end statements = if node.statements.nil? bounds(node.location) @@ -2386,6 +2512,11 @@ def visit_lambda_node(node) on_tlambeg(node.opening) end + unless braces + bounds(node.opening_loc) + on_kw("do") + end + body = case node.body when nil @@ -2407,6 +2538,11 @@ def visit_lambda_node(node) raise end + unless braces + bounds(node.closing_loc) + on_kw("end") + end + bounds(node.location) on_lambda(parameters, body) end @@ -2497,6 +2633,8 @@ def visit_match_last_line_node(node) # ^^^^^^^^^^ def visit_match_predicate_node(node) value = visit(node.value) + bounds(node.operator_loc) + on_kw("in") pattern = on_in(visit_pattern_node(node.pattern), nil, nil) on_case(value, pattern) @@ -2526,6 +2664,9 @@ def visit_error_recovery_node(node) # module Foo; end # ^^^^^^^^^^^^^^^ def visit_module_node(node) + bounds(node.module_keyword_loc) + on_kw("module") + constant_path = if node.constant_path.is_a?(ConstantReadNode) bounds(node.constant_path.location) @@ -2536,6 +2677,9 @@ def visit_module_node(node) bodystmt = visit_body_node(node.constant_path.location, node.body, true) + 
bounds(node.end_keyword_loc) + on_kw("end") + bounds(node.location) on_module(constant_path, bodystmt) end @@ -2617,6 +2761,9 @@ def visit_multi_write_node(node) # next foo # ^^^^^^^^ def visit_next_node(node) + bounds(node.keyword_loc) + on_kw("next") + if node.arguments.nil? bounds(node.location) on_next(on_args_new) @@ -2638,6 +2785,8 @@ def visit_nil_node(node) # def foo(&nil); end # ^^^^ def visit_no_block_parameter_node(node) + bounds(node.keyword_loc) + on_kw("nil") bounds(node.location) on_blockarg(:nil) end @@ -2645,6 +2794,8 @@ def visit_no_block_parameter_node(node) # def foo(**nil); end # ^^^^^ def visit_no_keywords_parameter_node(node) + bounds(node.keyword_loc) + on_kw("nil") bounds(node.location) on_nokw_param(nil) @@ -2687,6 +2838,10 @@ def visit_optional_parameter_node(node) # ^^^^^^ def visit_or_node(node) left = visit(node.left) + if node.operator == "or" + bounds(node.operator_loc) + on_kw("or") + end right = visit(node.right) bounds(node.location) @@ -2752,6 +2907,9 @@ def visit_pinned_variable_node(node) # END {} # ^^^^^^ def visit_post_execution_node(node) + bounds(node.keyword_loc) + on_kw("END") + statements = if node.statements.nil? bounds(node.location) @@ -2767,6 +2925,9 @@ def visit_post_execution_node(node) # BEGIN {} # ^^^^^^^^ def visit_pre_execution_node(node) + bounds(node.keyword_loc) + on_kw("BEGIN") + statements = if node.statements.nil? 
bounds(node.location) @@ -2813,6 +2974,7 @@ def visit_rational_node(node) # ^^^^ def visit_redo_node(node) bounds(node.location) + on_kw("redo") on_redo end @@ -2855,6 +3017,9 @@ def visit_required_parameter_node(node) # foo rescue bar # ^^^^^^^^^^^^^^ def visit_rescue_modifier_node(node) + bounds(node.keyword_loc) + on_kw("rescue") + expression = visit_write_value(node.expression) rescue_expression = visit(node.rescue_expression) @@ -2865,6 +3030,9 @@ def visit_rescue_modifier_node(node) # begin; rescue; end # ^^^^^^^ def visit_rescue_node(node) + bounds(node.keyword_loc) + on_kw("rescue") + exceptions = case node.exceptions.length when 0 @@ -2936,6 +3104,7 @@ def visit_rest_parameter_node(node) # ^^^^^ def visit_retry_node(node) bounds(node.location) + on_kw("retry") on_retry end @@ -2945,6 +3114,9 @@ def visit_retry_node(node) # return 1 # ^^^^^^^^ def visit_return_node(node) + bounds(node.keyword_loc) + on_kw("return") + if node.arguments.nil? bounds(node.location) on_return0 @@ -2971,9 +3143,15 @@ def visit_shareable_constant_node(node) # class << self; end # ^^^^^^^^^^^^^^^^^^ def visit_singleton_class_node(node) + bounds(node.class_keyword_loc) + on_kw("class") + expression = visit(node.expression) bodystmt = visit_body_node(node.body&.location || node.end_keyword_loc, node.body) + bounds(node.end_keyword_loc) + on_kw("end") + bounds(node.location) on_sclass(expression, bodystmt) end @@ -3180,6 +3358,9 @@ def visit_string_node(node) # super(foo) # ^^^^^^^^^^ def visit_super_node(node) + bounds(node.keyword_loc) + on_kw("super") + arguments, block, has_ripper_block = visit_call_node_arguments(node.arguments, node.block, trailing_comma?(node.arguments&.location || node.location, node.rparen_loc || node.location)) if !node.lparen_loc.nil? 
@@ -3233,6 +3414,9 @@ def visit_true_node(node) # undef foo # ^^^^^^^^^ def visit_undef_node(node) + bounds(node.keyword_loc) + on_kw("undef") + names = visit_all(node.names) bounds(node.location) @@ -3246,7 +3430,13 @@ def visit_undef_node(node) # ^^^^^^^^^^^^^^ def visit_unless_node(node) if node.statements.nil? || (node.predicate.location.start_offset < node.statements.location.start_offset) + bounds(node.keyword_loc) + on_kw("unless") predicate = visit(node.predicate) + if node.then_keyword_loc + bounds(node.then_keyword_loc) + on_kw("then") + end statements = if node.statements.nil? bounds(node.location) @@ -3256,10 +3446,17 @@ def visit_unless_node(node) end else_clause = visit(node.else_clause) + if node.end_keyword_loc && !node.else_clause + bounds(node.end_keyword_loc) + on_kw("end") + end + bounds(node.location) on_unless(predicate, statements, else_clause) else statements = visit(node.statements.body.first) + bounds(node.keyword_loc) + on_kw("unless") predicate = visit(node.predicate) bounds(node.location) @@ -3273,7 +3470,14 @@ def visit_unless_node(node) # bar until foo # ^^^^^^^^^^^^^ def visit_until_node(node) + bounds(node.keyword_loc) + on_kw("until") + if node.statements.nil? || (node.predicate.location.start_offset < node.statements.location.start_offset) + if node.do_keyword_loc + bounds(node.do_keyword_loc) + on_kw("do") + end predicate = visit(node.predicate) statements = if node.statements.nil? @@ -3283,6 +3487,11 @@ def visit_until_node(node) visit(node.statements) end + if node.closing_loc + bounds(node.closing_loc) + on_kw("end") + end + bounds(node.location) on_until(predicate, statements) else @@ -3300,7 +3509,14 @@ def visit_when_node(node) # This is a special case where we're not going to call on_when directly # because we don't have access to the subsequent. Instead, we'll return # the component parts and let the parent node handle it. 
+ bounds(node.keyword_loc) + on_kw("when") + conditions = visit_arguments(node.conditions) + if node.then_keyword_loc + bounds(node.then_keyword_loc) + on_kw("then") + end statements = if node.statements.nil? bounds(node.location) @@ -3319,7 +3535,17 @@ def visit_when_node(node) # ^^^^^^^^^^^^^ def visit_while_node(node) if node.statements.nil? || (node.predicate.location.start_offset < node.statements.location.start_offset) + bounds(node.keyword_loc) + on_kw("while") + if node.do_keyword_loc + bounds(node.do_keyword_loc) + on_kw("do") + end predicate = visit(node.predicate) + if node.closing_loc + bounds(node.closing_loc) + on_kw("end") + end statements = if node.statements.nil? bounds(node.location) @@ -3332,6 +3558,8 @@ def visit_while_node(node) on_while(predicate, statements) else statements = visit(node.statements.body.first) + bounds(node.keyword_loc) + on_kw("while") predicate = visit(node.predicate) bounds(node.location) @@ -3367,6 +3595,9 @@ def visit_x_string_node(node) # yield 1 # ^^^^^^^ def visit_yield_node(node) + bounds(node.keyword_loc) + on_kw("yield") + if node.arguments.nil? && node.lparen_loc.nil? 
bounds(node.location) on_yield0 diff --git a/prism/templates/template.rb b/prism/templates/template.rb index 2e7d3b107f1a39..0fdeda561f3573 100755 --- a/prism/templates/template.rb +++ b/prism/templates/template.rb @@ -690,9 +690,9 @@ def locals "javascript/src/deserialize.js", "javascript/src/nodes.js", "javascript/src/visitor.js", - "java/api/target/generated-sources/java/org/ruby_lang/prism/Loader.java", - "java/api/target/generated-sources/java/org/ruby_lang/prism/Nodes.java", - "java/api/target/generated-sources/java/org/ruby_lang/prism/AbstractNodeVisitor.java", + "java/api/src/main/java-templates/org/ruby_lang/prism/Loader.java", + "java/api/src/main/java-templates/org/ruby_lang/prism/Nodes.java", + "java/api/src/main/java-templates/org/ruby_lang/prism/AbstractNodeVisitor.java", "lib/prism/compiler.rb", "lib/prism/dispatcher.rb", "lib/prism/dot_visitor.rb", diff --git a/test/prism/newline_test.rb b/test/prism/newline_test.rb index c8914b57dcfac8..97e698202df748 100644 --- a/test/prism/newline_test.rb +++ b/test/prism/newline_test.rb @@ -20,6 +20,7 @@ class NewlineTest < TestCase ruby/find_fixtures.rb ruby/find_test.rb ruby/parser_test.rb + ruby/ripper_test.rb ruby/ruby_parser_test.rb ] diff --git a/test/prism/ruby/ripper_test.rb b/test/prism/ruby/ripper_test.rb index 7274454e1b44ed..1d20bceb40d6dc 100644 --- a/test/prism/ruby/ripper_test.rb +++ b/test/prism/ruby/ripper_test.rb @@ -136,8 +136,93 @@ def test_lex_ignored_missing_heredoc_end end end - UNSUPPORTED_EVENTS = %i[comma ignored_nl kw label_end lbrace lbracket lparen nl op rbrace rbracket rparen semicolon sp words_sep ignored_sp] + # Events that are currently not emitted + UNSUPPORTED_EVENTS = %i[comma ignored_nl label_end lbrace lbracket lparen nl op rbrace rbracket rparen semicolon sp words_sep ignored_sp] SUPPORTED_EVENTS = Translation::Ripper::EVENTS - UNSUPPORTED_EVENTS + # Events that assert against their line/column + CHECK_LOCATION_EVENTS = %i[kw] + IGNORE_FOR_SORT_EVENTS = %i[ + stmts_new 
stmts_add bodystmt void_stmt + args_new args_add args_add_star args_add_block arg_paren method_add_arg + mlhs_new mlhs_add_star + word_new words_new symbols_new qwords_new qsymbols_new xstring_new regexp_new + words_add symbols_add qwords_add qsymbols_add + regexp_end tstring_end heredoc_end + call command fcall vcall + field aref_field var_field var_ref block_var ident params + string_content heredoc_dedent unary binary dyna_symbol + comment magic_comment embdoc embdoc_beg embdoc_end arg_ambiguous + ] + SORT_IGNORE = { + aref: [ + "blocks.txt", + "command_method_call.txt", + "whitequark/ruby_bug_13547.txt", + ], + assoc_new: [ + "case_in_hash_key.txt", + "whitequark/parser_bug_525.txt", + "whitequark/ruby_bug_11380.txt", + ], + bare_assoc_hash: [ + "case_in_hash_key.txt", + "method_calls.txt", + "whitequark/parser_bug_525.txt", + "whitequark/ruby_bug_11380.txt", + ], + brace_block: [ + "super.txt", + "unparser/corpus/literal/super.txt" + ], + command_call: [ + "blocks.txt", + "case_in_hash_key.txt", + "seattlerb/block_call_dot_op2_cmd_args_do_block.txt", + "seattlerb/block_call_operation_colon.txt", + "seattlerb/block_call_operation_dot.txt", + ], + const_path_field: [ + "seattlerb/const_2_op_asgn_or2.txt", + "seattlerb/const_op_asgn_or.txt", + "whitequark/const_op_asgn.txt", + ], + const_path_ref: ["unparser/corpus/literal/defs.txt"], + do_block: ["whitequark/super_block.txt"], + embexpr_end: ["seattlerb/str_interp_ternary_or_label.txt"], + rest_param: ["whitequark/send_lambda.txt"], + top_const_field: [ + "seattlerb/const_3_op_asgn_or.txt", + "seattlerb/const_op_asgn_and1.txt", + "seattlerb/const_op_asgn_and2.txt", + "whitequark/const_op_asgn.txt", + ], + mlhs_paren: ["unparser/corpus/literal/for.txt"], + mlhs_add: [ + "whitequark/for_mlhs.txt", + ], + kw: [ + "defined.txt", + "for.txt", + "seattlerb/block_kw__required.txt", + "seattlerb/case_in_42.txt", + "seattlerb/case_in_67.txt", + "seattlerb/case_in_86_2.txt", + "seattlerb/case_in_86.txt", + 
"seattlerb/case_in_hash_pat_paren_true.txt", + "seattlerb/flip2_env_lvar.txt", + "unless.txt", + "unparser/corpus/semantic/and.txt", + "whitequark/class.txt", + "whitequark/find_pattern.txt", + "whitequark/pattern_matching_hash.txt", + "whitequark/pattern_matching_implicit_array_match.txt", + "whitequark/pattern_matching_ranges.txt", + "whitequark/super_block.txt", + "write_command_operator.txt", + ], + } + SORT_IGNORE.default = [] + SORT_EVENTS = SUPPORTED_EVENTS - IGNORE_FOR_SORT_EVENTS module Events attr_reader :events @@ -147,9 +232,20 @@ def initialize(...) @events = [] end + def sorted_events + @events.select do |e,| + next false if e == :kw && @events.any? { |e,| e == :if_mod || e == :while_mod || e == :until_mod || e == :rescue || e == :rescue_mod || e == :while || e == :ensure } + SORT_EVENTS.include?(e) && !SORT_IGNORE[e].include?(filename) + end + end + SUPPORTED_EVENTS.each do |event| define_method(:"on_#{event}") do |*args| - @events << [event, *args.map(&:to_s)] + if CHECK_LOCATION_EVENTS.include?(event) + @events << [event, lineno, column, *args.map(&:to_s)] + else + @events << [event, *args.map(&:to_s)] + end super(*args) end end @@ -177,12 +273,14 @@ class ObjectEvents < Translation::Ripper object_events = ObjectEvents.new(source) assert_nothing_raised { object_events.parse } - ripper = RipperEvents.new(source) - prism = PrismEvents.new(source) + ripper = RipperEvents.new(source, fixture.path) + prism = PrismEvents.new(source, fixture.path) ripper.parse prism.parse - # This makes sure that the content is the same. Ordering is not correct for now. 
+ # Check that the same events are emitted, regardless of order assert_equal(ripper.events.sort, prism.events.sort) + # Check a subset of events against the correct order + assert_equal(ripper.sorted_events, prism.sorted_events) end end diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index 2cde7388fbe127..7394d3d96adf9a 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -2549,7 +2549,7 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard asm.cmp(klass, Opnd::Value(expected_class)); asm.jne(jit, side_exit); - } else if guard_type.is_subtype(types::String) { + } else if guard_type.is_subtype(types::TypedTData) { let side = side_exit(jit, state, GuardType(guard_type)); // Check special constant @@ -2560,13 +2560,15 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard asm.cmp(val, Qfalse.into()); asm.je(jit, side.clone()); + // Check the builtin type and RUBY_TYPED_FL_IS_TYPED_DATA with mask and compare let val = asm.load_mem(val); - let flags = asm.load(Opnd::mem(VALUE_BITS, val, RUBY_OFFSET_RBASIC_FLAGS)); - let tag = asm.and(flags, Opnd::UImm(RUBY_T_MASK as u64)); - asm.cmp(tag, Opnd::UImm(RUBY_T_STRING as u64)); + let mask = RUBY_T_MASK.to_usize() | RUBY_TYPED_FL_IS_TYPED_DATA.to_usize(); + let expected = RUBY_T_DATA.to_usize() | RUBY_TYPED_FL_IS_TYPED_DATA.to_usize(); + let masked = asm.and(flags, mask.into()); + asm.cmp(masked, expected.into()); asm.jne(jit, side); - } else if guard_type.is_subtype(types::Array) { + } else if let Some(builtin_type) = guard_type.builtin_type_equivalent() { let side = side_exit(jit, state, GuardType(guard_type)); // Check special constant @@ -2577,11 +2579,11 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard asm.cmp(val, Qfalse.into()); asm.je(jit, side.clone()); + // Mask and check the builtin type let val = asm.load_mem(val); - let flags = asm.load(Opnd::mem(VALUE_BITS, val, RUBY_OFFSET_RBASIC_FLAGS)); let tag = 
asm.and(flags, Opnd::UImm(RUBY_T_MASK as u64)); - asm.cmp(tag, Opnd::UImm(RUBY_T_ARRAY as u64)); + asm.cmp(tag, Opnd::UImm(builtin_type as u64)); asm.jne(jit, side); } else if guard_type.bit_equal(types::HeapBasicObject) { let side_exit = side_exit(jit, state, GuardType(guard_type)); diff --git a/zjit/src/codegen_tests.rs b/zjit/src/codegen_tests.rs index 40278d112e9a4b..7c8a3c758b1fbd 100644 --- a/zjit/src/codegen_tests.rs +++ b/zjit/src/codegen_tests.rs @@ -3623,6 +3623,97 @@ fn test_attr_accessor_getivar() { assert_snapshot!(assert_compiles("c = C.new; [test(c), test(c)]"), @"[4, 4]"); } +#[test] +fn test_getivar_t_data_then_string() { + // This is a regression test for a type confusion miscomp where + // we end up reading the fields object using an offset off of a + // string, assuming that it has the same layout as a T_DATA object. + // At the time of writing the fields object of strings are stored + // in a global table, out-of-line of each string. + // The string and the thread end up sharing one shape ID. + set_call_threshold(2); + eval(r#" + module GetThousand + def test = @var1000 + end + class Thread + include GetThousand + end + class String + include GetThousand + end + OBJ = Thread.new { } + OBJ.join + STR = +'' + (0..1000).each do |i| + ivar_name = :"@var#{i}" + OBJ.instance_variable_set(ivar_name, i) + STR.instance_variable_set(ivar_name, i) + end + OBJ.test; OBJ.test # profile and compile for Thread (T_DATA) + "#); + assert_snapshot!(assert_compiles("[STR.test, STR.test]"), @"[1000, 1000]"); +} + +#[test] +fn test_getivar_t_object_then_string() { + // This test constructs an object and a string that have the same set of ivars. + // They wouldn't share the same shape ID, though, and we rely on this fact in + // our guards. 
+ set_call_threshold(2); + eval(r#" + module GetThousand + def test = @var1000 + end + class MyObject + include GetThousand + end + class String + include GetThousand + end + OBJ = MyObject.new + STR = +'' + (0..1000).each do |i| + ivar_name = :"@var#{i}" + OBJ.instance_variable_set(ivar_name, i) + STR.instance_variable_set(ivar_name, i) + end + OBJ.test; OBJ.test # profile and compile for MyObject + "#); + assert_snapshot!(assert_compiles("[STR.test, STR.test]"), @"[1000, 1000]"); +} + +#[test] +fn test_getivar_t_class_then_string() { + // This is a regression test for a type confusion miscomp where + // we end up reading the fields object using an offset off of a + // string, assuming that it has the same layout as a T_CLASS object. + // At the time of writing the fields object of strings are stored + // in a global table, out-of-line of each string. + // The string and the class end up sharing one shape ID. + set_call_threshold(2); + eval(r#" + module GetThousand + def test = @var1000 + end + class MyClass + extend GetThousand + end + class String + include GetThousand + end + STR = +'' + (0..1000).each do |i| + ivar_name = :"@var#{i}" + MyClass.instance_variable_set(ivar_name, i) + STR.instance_variable_set(ivar_name, i) + end + p MyClass.test; p MyClass.test # profile and compile for MyClass + p STR.test + "#); + assert_snapshot!(assert_compiles("[STR.test, STR.test]"), @"[1000, 1000]"); +} + #[test] fn test_attr_accessor_setivar() { eval(" diff --git a/zjit/src/cruby.rs b/zjit/src/cruby.rs index ee9f7fa7e175ef..3c00275478523c 100644 --- a/zjit/src/cruby.rs +++ b/zjit/src/cruby.rs @@ -604,6 +604,7 @@ impl VALUE { unsafe { rb_jit_class_fields_embedded_p(self) } } + /// Typed `T_DATA` made from `TypedData_Make_Struct()` (e.g. 
Thread, ARGF) pub fn typed_data_p(self) -> bool { !self.special_const_p() && self.builtin_type() == RUBY_T_DATA && diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index e8db2b6af2649a..bc7be7498a3c81 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -2228,22 +2228,6 @@ impl<'a> FunctionPrinter<'a> { } } -/// Pretty printer for [`Function`]. -pub struct FunctionGraphvizPrinter<'a> { - fun: &'a Function, - ptr_map: PtrPrintMap, -} - -impl<'a> FunctionGraphvizPrinter<'a> { - pub fn new(fun: &'a Function) -> Self { - let mut ptr_map = PtrPrintMap::identity(); - if cfg!(test) { - ptr_map.map_ptrs = true; - } - Self { fun, ptr_map } - } -} - /// Union-Find (Disjoint-Set) is a data structure for managing disjoint sets that has an interface /// of two operations: /// @@ -4431,6 +4415,23 @@ impl Function { } } + /// This puts a guard that establishes the precondition for [Self::load_ivar] + fn load_ivar_guard_type(&mut self, block: BlockId, recv: InsnId, recv_type: ProfiledType, state: InsnId) -> InsnId { + if recv_type.class().is_subclass_of(unsafe { rb_cClass }) == ClassRelationship::Subclass { + // Check class first since `Class < Module` + self.push_insn(block, Insn::GuardType { val: recv, guard_type: types::Class, state }) + } else if recv_type.class().is_subclass_of(unsafe { rb_cModule }) == ClassRelationship::Subclass { + self.push_insn(block, Insn::GuardType { val: recv, guard_type: types::Module, state }) + } else if recv_type.flags().is_typed_data() { + self.push_insn(block, Insn::GuardType { val: recv, guard_type: types::TypedTData, state }) + } else { + // HeapBasicObject is wider than T_OBJECT, but shapes for T_OBJECTs are in a pool of + // its own and are guaranteed to be different from shapes of any other T_* types. So + // the shape check that follows already covers checking for T_OBJECT. 
+ self.push_insn(block, Insn::GuardType { val: recv, guard_type: types::HeapBasicObject, state }) + } + } + fn load_ivar(&mut self, block: BlockId, self_val: InsnId, recv_type: ProfiledType, id: ID, state: InsnId) -> InsnId { // Too-complex shapes use hash tables; rb_shape_get_iv_index doesn't support them. // Callers must filter these out before calling load_ivar. @@ -4500,7 +4501,7 @@ impl Function { self.count(block, Counter::getivar_fallback_too_complex); self.push_insn_id(block, insn_id); continue; } - let self_val = self.push_insn(block, Insn::GuardType { val: self_val, guard_type: types::HeapBasicObject, state }); + let self_val = self.load_ivar_guard_type(block, self_val, recv_type, state); let shape = self.load_shape(block, self_val); self.guard_shape(block, shape, recv_type.shape(), state); let replacement = self.load_ivar(block, self_val, recv_type, id, state); @@ -4529,7 +4530,7 @@ impl Function { self.count(block, Counter::definedivar_fallback_too_complex); self.push_insn_id(block, insn_id); continue; } - let self_val = self.push_insn(block, Insn::GuardType { val: self_val, guard_type: types::HeapBasicObject, state }); + let self_val = self.load_ivar_guard_type(block, self_val, recv_type, state); let shape = self.load_shape(block, self_val); self.guard_shape(block, shape, recv_type.shape(), state); let mut ivar_index: u16 = 0; @@ -4601,7 +4602,7 @@ impl Function { } // Fall through to emitting the ivar write } - let self_val = self.push_insn(block, Insn::GuardType { val: self_val, guard_type: types::HeapBasicObject, state }); + let self_val = self.load_ivar_guard_type(block, self_val, recv_type, state); let shape = self.load_shape(block, self_val); self.guard_shape(block, shape, recv_type.shape(), state); // Current shape contains this ivar @@ -5912,13 +5913,6 @@ impl Function { Some(DumpHIR::Debug) => println!("Optimized HIR:\n{:#?}", &self), None => {}, } - - if let Some(filename) = &get_option!(dump_hir_graphviz) { - use std::fs::OpenOptions; - use 
std::io::Write; - let mut file = OpenOptions::new().append(true).open(filename).unwrap(); - writeln!(file, "{}", FunctionGraphvizPrinter::new(self)).unwrap(); - } } pub fn dump_iongraph(&self, function_name: &str, passes: Vec) { @@ -6460,87 +6454,6 @@ impl<'a> std::fmt::Display for FunctionPrinter<'a> { } } -struct HtmlEncoder<'a, 'b> { - formatter: &'a mut std::fmt::Formatter<'b>, -} - -impl<'a, 'b> std::fmt::Write for HtmlEncoder<'a, 'b> { - fn write_str(&mut self, s: &str) -> std::fmt::Result { - for ch in s.chars() { - match ch { - '<' => self.formatter.write_str("<")?, - '>' => self.formatter.write_str(">")?, - '&' => self.formatter.write_str("&")?, - '"' => self.formatter.write_str(""")?, - '\'' => self.formatter.write_str("'")?, - _ => self.formatter.write_char(ch)?, - } - } - Ok(()) - } -} - -impl<'a> std::fmt::Display for FunctionGraphvizPrinter<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - macro_rules! write_encoded { - ($f:ident, $($arg:tt)*) => { - HtmlEncoder { formatter: $f }.write_fmt(format_args!($($arg)*)) - }; - } - use std::fmt::Write; - let fun = &self.fun; - let iseq_name = iseq_get_location(fun.iseq, 0); - write!(f, "digraph G {{ # ")?; - write_encoded!(f, "{iseq_name}")?; - writeln!(f)?; - writeln!(f, "node [shape=plaintext];")?; - writeln!(f, "mode=hier; overlap=false; splines=true;")?; - for block_id in fun.rpo() { - writeln!(f, r#" {block_id} [label=<"#)?; - write!(f, r#"")?; - for insn_id in &fun.blocks[block_id.0].insns { - let insn_id = fun.union_find.borrow().find_const(*insn_id); - let insn = fun.find(insn_id); - if matches!(insn, Insn::Snapshot {..}) { - continue; - } - write!(f, r#"")?; - } - writeln!(f, "
{block_id}("#)?; - if !fun.blocks[block_id.0].params.is_empty() { - let mut sep = ""; - for param in &fun.blocks[block_id.0].params { - write_encoded!(f, "{sep}{param}")?; - let insn_type = fun.type_of(*param); - if !insn_type.is_subtype(types::Empty) { - write_encoded!(f, ":{}", insn_type.print(&self.ptr_map))?; - } - sep = ", "; - } - } - let mut edges = vec![]; - writeln!(f, ") 
"#)?; - if insn.has_output() { - let insn_type = fun.type_of(insn_id); - if insn_type.is_subtype(types::Empty) { - write_encoded!(f, "{insn_id} = ")?; - } else { - write_encoded!(f, "{insn_id}:{} = ", insn_type.print(&self.ptr_map))?; - } - } - if let Insn::Jump(ref target) | Insn::IfTrue { ref target, .. } | Insn::IfFalse { ref target, .. } = insn { - edges.push((insn_id, target.target)); - } - write_encoded!(f, "{}", insn.print(&self.ptr_map, Some(fun.iseq)))?; - writeln!(f, " 
>];")?; - for (src, dst) in edges { - writeln!(f, " {block_id}:{src} -> {dst}:params:n;")?; - } - } - writeln!(f, "}}") - } -} - #[derive(Debug, Clone, PartialEq)] pub struct FrameState { pub iseq: IseqPtr, @@ -9274,131 +9187,3 @@ mod infer_tests { }); } } - -#[cfg(test)] -mod graphviz_tests { - use super::*; - use insta::assert_snapshot; - - #[track_caller] - fn hir_string(method: &str) -> String { - let iseq = crate::cruby::with_rubyvm(|| get_method_iseq("self", method)); - unsafe { crate::cruby::rb_zjit_profile_disable(iseq) }; - let mut function = iseq_to_hir(iseq).unwrap(); - function.optimize(); - function.validate().unwrap(); - format!("{}", FunctionGraphvizPrinter::new(&function)) - } - - #[test] - fn test_guard_fixnum_or_fixnum() { - eval(r#" - def test(x, y) = x | y - - test(1, 2) - "#); - assert_snapshot!(hir_string("test"), @r#" - digraph G { # test@<compiled>:2 - node [shape=plaintext]; - mode=hier; overlap=false; splines=true; - bb0 [label=< - - -
bb0() 
Entries bb1, bb2 
>]; - bb1 [label=< - - - - - - - -
bb1() 
EntryPoint interpreter 
v1:BasicObject = LoadSelf 
v2:CPtr = LoadSP 
v3:BasicObject = LoadField v2, :x@0x1000 
v4:BasicObject = LoadField v2, :y@0x1001 
Jump bb3(v1, v3, v4) 
>]; - bb1:v5 -> bb3:params:n; - bb2 [label=< - - - - - - -
bb2() 
EntryPoint JIT(0) 
v7:BasicObject = LoadArg :self@0 
v8:BasicObject = LoadArg :x@1 
v9:BasicObject = LoadArg :y@2 
Jump bb3(v7, v8, v9) 
>]; - bb2:v10 -> bb3:params:n; - bb3 [label=< - - - - - - - - -
bb3(v11:BasicObject, v12:BasicObject, v13:BasicObject) 
PatchPoint NoTracePoint 
PatchPoint MethodRedefined(Integer@0x1008, |@0x1010, cme:0x1018) 
v28:Fixnum = GuardType v12, Fixnum 
v29:Fixnum = GuardType v13, Fixnum 
v30:Fixnum = FixnumOr v28, v29 
CheckInterrupts 
Return v30 
>]; - } - "#); - } - - #[test] - fn test_multiple_blocks() { - eval(r#" - def test(c) - if c - 3 - else - 4 - end - end - - test(1) - test("x") - "#); - assert_snapshot!(hir_string("test"), @r#" - digraph G { # test@<compiled>:3 - node [shape=plaintext]; - mode=hier; overlap=false; splines=true; - bb0 [label=< - - -
bb0() 
Entries bb1, bb2 
>]; - bb1 [label=< - - - - - - -
bb1() 
EntryPoint interpreter 
v1:BasicObject = LoadSelf 
v2:CPtr = LoadSP 
v3:BasicObject = LoadField v2, :c@0x1000 
Jump bb3(v1, v3) 
>]; - bb1:v4 -> bb3:params:n; - bb2 [label=< - - - - - -
bb2() 
EntryPoint JIT(0) 
v6:BasicObject = LoadArg :self@0 
v7:BasicObject = LoadArg :c@1 
Jump bb3(v6, v7) 
>]; - bb2:v8 -> bb3:params:n; - bb3 [label=< - - - - - - - - - - - -
bb3(v9:BasicObject, v10:BasicObject) 
PatchPoint NoTracePoint 
CheckInterrupts 
v16:CBool = Test v10 
v17:Falsy = RefineType v10, Falsy 
IfFalse v16, bb4(v9, v17) 
v19:Truthy = RefineType v10, Truthy 
PatchPoint NoTracePoint 
v22:Fixnum[3] = Const Value(3) 
CheckInterrupts 
Return v22 
>]; - bb3:v18 -> bb4:params:n; - bb4 [label=< - - - - - -
bb4(v27:BasicObject, v28:Falsy) 
PatchPoint NoTracePoint 
v32:Fixnum[4] = Const Value(4) 
CheckInterrupts 
Return v32 
>]; - } - "#); - } -} diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs index 3cf9a8f3fa69e8..fde0ab0583d279 100644 --- a/zjit/src/hir/opt_tests.rs +++ b/zjit/src/hir/opt_tests.rs @@ -7346,7 +7346,7 @@ mod hir_opt_tests { Jump bb3(v4) bb3(v6:BasicObject): PatchPoint SingleRactorMode - v17:HeapBasicObject = GuardType v6, HeapBasicObject + v17:Module = GuardType v6, Module v18:CShape = LoadField v17, :_shape_id@0x1000 v19:CShape[0x1001] = GuardBitEquals v18, CShape(0x1001) PatchPoint RootBoxOnly @@ -7381,7 +7381,7 @@ mod hir_opt_tests { Jump bb3(v4) bb3(v6:BasicObject): PatchPoint SingleRactorMode - v17:HeapBasicObject = GuardType v6, HeapBasicObject + v17:Module = GuardType v6, Module v18:CShape = LoadField v17, :_shape_id@0x1000 v19:CShape[0x1001] = GuardBitEquals v18, CShape(0x1001) PatchPoint RootBoxOnly @@ -7414,7 +7414,7 @@ mod hir_opt_tests { Jump bb3(v4) bb3(v6:BasicObject): PatchPoint SingleRactorMode - v17:HeapBasicObject = GuardType v6, HeapBasicObject + v17:Class = GuardType v6, Class v18:CShape = LoadField v17, :_shape_id@0x1000 v19:CShape[0x1001] = GuardBitEquals v18, CShape(0x1001) PatchPoint RootBoxOnly @@ -7449,7 +7449,7 @@ mod hir_opt_tests { Jump bb3(v4) bb3(v6:BasicObject): PatchPoint SingleRactorMode - v17:HeapBasicObject = GuardType v6, HeapBasicObject + v17:Class = GuardType v6, Class v18:CShape = LoadField v17, :_shape_id@0x1000 v19:CShape[0x1001] = GuardBitEquals v18, CShape(0x1001) PatchPoint RootBoxOnly @@ -7520,7 +7520,7 @@ mod hir_opt_tests { Jump bb3(v4) bb3(v6:BasicObject): PatchPoint SingleRactorMode - v17:HeapBasicObject = GuardType v6, HeapBasicObject + v17:TypedTData = GuardType v6, TypedTData v18:CShape = LoadField v17, :_shape_id@0x1000 v19:CShape[0x1001] = GuardBitEquals v18, CShape(0x1001) v20:RubyValue = LoadField v17, :_fields_obj@0x1002 @@ -7556,7 +7556,7 @@ mod hir_opt_tests { Jump bb3(v4) bb3(v6:BasicObject): PatchPoint SingleRactorMode - v17:HeapBasicObject = GuardType v6, HeapBasicObject + v17:TypedTData 
= GuardType v6, TypedTData v18:CShape = LoadField v17, :_shape_id@0x1000 v19:CShape[0x1001] = GuardBitEquals v18, CShape(0x1001) v20:RubyValue = LoadField v17, :_fields_obj@0x1002 diff --git a/zjit/src/hir_type/gen_hir_type.rb b/zjit/src/hir_type/gen_hir_type.rb index 5934dcb2d0ea41..37919425cee51f 100644 --- a/zjit/src/hir_type/gen_hir_type.rb +++ b/zjit/src/hir_type/gen_hir_type.rb @@ -132,6 +132,11 @@ def final_type name, base: $object, c_name: nil true_exact = final_type "TrueClass", c_name: "rb_cTrueClass" false_exact = final_type "FalseClass", c_name: "rb_cFalseClass" +# Typed T_DATA objects (RTYPEDDATA_P). These have a distinct memory layout +# for field access (fields_obj at a fixed offset in RTypedData). These +# don't have a common class ancestor below BasicObject. +basic_object.subtype "TypedTData" + # Build the cvalue object universe. This is for C-level types that may be # passed around when calling into the Ruby VM or after some strength reduction # of HIR. diff --git a/zjit/src/hir_type/hir_type.inc.rs b/zjit/src/hir_type/hir_type.inc.rs index 99697deb82a23b..f37cd57b319445 100644 --- a/zjit/src/hir_type/hir_type.inc.rs +++ b/zjit/src/hir_type/hir_type.inc.rs @@ -4,7 +4,7 @@ mod bits { pub const Array: u64 = ArrayExact | ArraySubclass; pub const ArrayExact: u64 = 1u64 << 0; pub const ArraySubclass: u64 = 1u64 << 1; - pub const BasicObject: u64 = BasicObjectExact | BasicObjectSubclass | Object; + pub const BasicObject: u64 = BasicObjectExact | BasicObjectSubclass | Object | TypedTData; pub const BasicObjectExact: u64 = 1u64 << 2; pub const BasicObjectSubclass: u64 = 1u64 << 3; pub const Bignum: u64 = 1u64 << 4; @@ -73,20 +73,22 @@ mod bits { pub const Symbol: u64 = DynamicSymbol | StaticSymbol; pub const TrueClass: u64 = 1u64 << 43; pub const Truthy: u64 = BasicObject & !Falsy; - pub const Undef: u64 = 1u64 << 44; - pub const AllBitPatterns: [(&str, u64); 74] = [ + pub const TypedTData: u64 = 1u64 << 44; + pub const Undef: u64 = 1u64 << 45; + pub 
const AllBitPatterns: [(&str, u64); 75] = [ ("Any", Any), ("RubyValue", RubyValue), ("Immediate", Immediate), ("Undef", Undef), ("BasicObject", BasicObject), - ("Object", Object), ("NotNil", NotNil), ("Truthy", Truthy), + ("HeapBasicObject", HeapBasicObject), + ("TypedTData", TypedTData), + ("Object", Object), ("BuiltinExact", BuiltinExact), ("BoolExact", BoolExact), ("TrueClass", TrueClass), - ("HeapBasicObject", HeapBasicObject), ("HeapObject", HeapObject), ("String", String), ("Subclass", Subclass), @@ -150,7 +152,7 @@ mod bits { ("ArrayExact", ArrayExact), ("Empty", Empty), ]; - pub const NumTypeBits: u64 = 45; + pub const NumTypeBits: u64 = 46; } pub mod types { use super::*; @@ -227,6 +229,7 @@ pub mod types { pub const Symbol: Type = Type::from_bits(bits::Symbol); pub const TrueClass: Type = Type::from_bits(bits::TrueClass); pub const Truthy: Type = Type::from_bits(bits::Truthy); + pub const TypedTData: Type = Type::from_bits(bits::TypedTData); pub const Undef: Type = Type::from_bits(bits::Undef); pub const ExactBitsAndClass: [(u64, *const VALUE); 17] = [ (bits::ObjectExact, &raw const crate::cruby::rb_cObject), diff --git a/zjit/src/hir_type/mod.rs b/zjit/src/hir_type/mod.rs index 1e6c0d2df7f8b8..0ebdc4be60a17a 100644 --- a/zjit/src/hir_type/mod.rs +++ b/zjit/src/hir_type/mod.rs @@ -1,6 +1,7 @@ //! High-level intermediate representation types. 
#![allow(non_upper_case_globals)] +use crate::cruby; use crate::cruby::{rb_block_param_proxy, Qfalse, Qnil, Qtrue, RUBY_T_ARRAY, RUBY_T_CLASS, RUBY_T_HASH, RUBY_T_MODULE, RUBY_T_STRING, VALUE}; use crate::cruby::{rb_cInteger, rb_cFloat, rb_cArray, rb_cHash, rb_cString, rb_cSymbol, rb_cRange, rb_cModule, rb_zjit_singleton_class_p}; use crate::cruby::ClassRelationship; @@ -213,6 +214,7 @@ impl Type { else if val.class_of() == unsafe { rb_cSymbol } { bits::DynamicSymbol } else if let Some(bits) = Self::bits_from_exact_class(val.class_of()) { bits } else if let Some(bits) = Self::bits_from_subclass(val.class_of()) { bits } + else if val.typed_data_p() { bits::TypedTData } else { unreachable!("Class {} is not a subclass of BasicObject! Don't know what to do.", get_class_name(val.class_of())) @@ -429,6 +431,24 @@ impl Type { matches!(self.spec, Specialization::Object(_)) } + /// Find a `T_*` type that is exactly as wide as `self`. + pub fn builtin_type_equivalent(&self) -> Option { + if self.bit_equal(types::Array) { + Some(cruby::RUBY_T_ARRAY) + } else if self.bit_equal(types::Class) { + Some(cruby::RUBY_T_CLASS) + } else if self.bit_equal(types::Module) { + Some(cruby::RUBY_T_MODULE) + } else if self.bit_equal(types::String) { + Some(cruby::RUBY_T_STRING) + } else if self.bit_equal(types::Hash) { + Some(cruby::RUBY_T_HASH) + } else { + // Note that types::TypedTData is narrower than T_DATA, so not here. + None + } + } + fn is_builtin(class: VALUE) -> bool { types::ExactBitsAndClass .iter()