From 8824fd3bb5c79086b46d2205911d11a7f3fd81f9 Mon Sep 17 00:00:00 2001 From: Charles Oliver Nutter Date: Tue, 24 Mar 2026 11:51:15 -0500 Subject: [PATCH 1/4] [ruby/prism] Generate templated sources under main/java-templates This path avoids the sources getting wiped out during `mvn clean`, since they are not generated during the maven build. This patch also moves the generated WASM build under src/main/wasm since it is really a source file and not a test file. It will not be included in the built artifact. https://github.com/ruby/prism/commit/08dba29eb5 --- prism/templates/template.rb | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/prism/templates/template.rb b/prism/templates/template.rb index 2e7d3b107f1a39..0fdeda561f3573 100755 --- a/prism/templates/template.rb +++ b/prism/templates/template.rb @@ -690,9 +690,9 @@ def locals "javascript/src/deserialize.js", "javascript/src/nodes.js", "javascript/src/visitor.js", - "java/api/target/generated-sources/java/org/ruby_lang/prism/Loader.java", - "java/api/target/generated-sources/java/org/ruby_lang/prism/Nodes.java", - "java/api/target/generated-sources/java/org/ruby_lang/prism/AbstractNodeVisitor.java", + "java/api/src/main/java-templates/org/ruby_lang/prism/Loader.java", + "java/api/src/main/java-templates/org/ruby_lang/prism/Nodes.java", + "java/api/src/main/java-templates/org/ruby_lang/prism/AbstractNodeVisitor.java", "lib/prism/compiler.rb", "lib/prism/dispatcher.rb", "lib/prism/dot_visitor.rb", From a8f3c34556bac709587ded4e0e4dad08932d5900 Mon Sep 17 00:00:00 2001 From: Alan Wu Date: Wed, 1 Apr 2026 10:13:31 -0400 Subject: [PATCH 2/4] ZJIT: Add missing guard on ivar access on T_{DATA,CLASS,MODULE} T_DATA, T_MODULE, and T_CLASS objects can share the exact same shape. The shape on these objects gives an index off of the fields array to get at the ivar. When two objects share the same shape, but differ in the T_* builtin type, however, the way to get to the fields array differs.
Previously, we did not guard the builtin type, so the guard allowed, say, loading `t_string[RCLASS_OFFSET_PRIME_FIELDS_OBJ]`. A classic type confusion situation that crashed. Guard the builtin type, in addition to the shape. Note that this is not necessary for T_OBJECTs since they never have the same shape as other builtin types. --- zjit/src/codegen.rs | 16 +++--- zjit/src/codegen_tests.rs | 91 +++++++++++++++++++++++++++++++ zjit/src/cruby.rs | 1 + zjit/src/hir.rs | 23 +++++++- zjit/src/hir/opt_tests.rs | 12 ++-- zjit/src/hir_type/gen_hir_type.rb | 5 ++ zjit/src/hir_type/hir_type.inc.rs | 15 +++-- zjit/src/hir_type/mod.rs | 20 +++++++ 8 files changed, 161 insertions(+), 22 deletions(-) diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index 2cde7388fbe127..7394d3d96adf9a 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -2549,7 +2549,7 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard asm.cmp(klass, Opnd::Value(expected_class)); asm.jne(jit, side_exit); - } else if guard_type.is_subtype(types::String) { + } else if guard_type.is_subtype(types::TypedTData) { let side = side_exit(jit, state, GuardType(guard_type)); // Check special constant @@ -2560,13 +2560,15 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard asm.cmp(val, Qfalse.into()); asm.je(jit, side.clone()); + // Check the builtin type and RUBY_TYPED_FL_IS_TYPED_DATA with mask and compare let val = asm.load_mem(val); - let flags = asm.load(Opnd::mem(VALUE_BITS, val, RUBY_OFFSET_RBASIC_FLAGS)); - let tag = asm.and(flags, Opnd::UImm(RUBY_T_MASK as u64)); - asm.cmp(tag, Opnd::UImm(RUBY_T_STRING as u64)); + let mask = RUBY_T_MASK.to_usize() | RUBY_TYPED_FL_IS_TYPED_DATA.to_usize(); + let expected = RUBY_T_DATA.to_usize() | RUBY_TYPED_FL_IS_TYPED_DATA.to_usize(); + let masked = asm.and(flags, mask.into()); + asm.cmp(masked, expected.into()); asm.jne(jit, side); - } else if guard_type.is_subtype(types::Array) { + } else if
let Some(builtin_type) = guard_type.builtin_type_equivalent() { let side = side_exit(jit, state, GuardType(guard_type)); // Check special constant @@ -2577,11 +2579,11 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard asm.cmp(val, Qfalse.into()); asm.je(jit, side.clone()); + // Mask and check the builtin type let val = asm.load_mem(val); - let flags = asm.load(Opnd::mem(VALUE_BITS, val, RUBY_OFFSET_RBASIC_FLAGS)); let tag = asm.and(flags, Opnd::UImm(RUBY_T_MASK as u64)); - asm.cmp(tag, Opnd::UImm(RUBY_T_ARRAY as u64)); + asm.cmp(tag, Opnd::UImm(builtin_type as u64)); asm.jne(jit, side); } else if guard_type.bit_equal(types::HeapBasicObject) { let side_exit = side_exit(jit, state, GuardType(guard_type)); diff --git a/zjit/src/codegen_tests.rs b/zjit/src/codegen_tests.rs index 40278d112e9a4b..7c8a3c758b1fbd 100644 --- a/zjit/src/codegen_tests.rs +++ b/zjit/src/codegen_tests.rs @@ -3623,6 +3623,97 @@ fn test_attr_accessor_getivar() { assert_snapshot!(assert_compiles("c = C.new; [test(c), test(c)]"), @"[4, 4]"); } +#[test] +fn test_getivar_t_data_then_string() { + // This is a regression test for a type confusion miscomp where + // we end up reading the fields object using an offset off of a + // string, assuming that it has the same layout as a T_DATA object. + // At the time of writing the fields object of strings are stored + // in a global table, out-of-line of each string. + // The string and the thread end up sharing one shape ID.
+ set_call_threshold(2); + eval(r#" + module GetThousand + def test = @var1000 + end + class Thread + include GetThousand + end + class String + include GetThousand + end + OBJ = Thread.new { } + OBJ.join + STR = +'' + (0..1000).each do |i| + ivar_name = :"@var#{i}" + OBJ.instance_variable_set(ivar_name, i) + STR.instance_variable_set(ivar_name, i) + end + OBJ.test; OBJ.test # profile and compile for Thread (T_DATA) + "#); + assert_snapshot!(assert_compiles("[STR.test, STR.test]"), @"[1000, 1000]"); +} + +#[test] +fn test_getivar_t_object_then_string() { + // This test constructs an object and a string that have the same set of ivars. + // They wouldn't share the same shape ID, though, and we rely on this fact in + // our guards. + set_call_threshold(2); + eval(r#" + module GetThousand + def test = @var1000 + end + class MyObject + include GetThousand + end + class String + include GetThousand + end + OBJ = MyObject.new + STR = +'' + (0..1000).each do |i| + ivar_name = :"@var#{i}" + OBJ.instance_variable_set(ivar_name, i) + STR.instance_variable_set(ivar_name, i) + end + OBJ.test; OBJ.test # profile and compile for MyObject + "#); + assert_snapshot!(assert_compiles("[STR.test, STR.test]"), @"[1000, 1000]"); +} + +#[test] +fn test_getivar_t_class_then_string() { + // This is a regression test for a type confusion miscomp where + // we end up reading the fields object using an offset off of a + // string, assuming that it has the same layout as a T_CLASS object. + // At the time of writing the fields object of strings are stored + // in a global table, out-of-line of each string. + // The string and the class end up sharing one shape ID.
+ set_call_threshold(2); + eval(r#" + module GetThousand + def test = @var1000 + end + class MyClass + extend GetThousand + end + class String + include GetThousand + end + STR = +'' + (0..1000).each do |i| + ivar_name = :"@var#{i}" + MyClass.instance_variable_set(ivar_name, i) + STR.instance_variable_set(ivar_name, i) + end + p MyClass.test; p MyClass.test # profile and compile for MyClass + p STR.test + "#); + assert_snapshot!(assert_compiles("[STR.test, STR.test]"), @"[1000, 1000]"); +} + #[test] fn test_attr_accessor_setivar() { eval(" diff --git a/zjit/src/cruby.rs b/zjit/src/cruby.rs index ee9f7fa7e175ef..3c00275478523c 100644 --- a/zjit/src/cruby.rs +++ b/zjit/src/cruby.rs @@ -604,6 +604,7 @@ impl VALUE { unsafe { rb_jit_class_fields_embedded_p(self) } } + /// Typed `T_DATA` made from `TypedData_Make_Struct()` (e.g. Thread, ARGF) pub fn typed_data_p(self) -> bool { !self.special_const_p() && self.builtin_type() == RUBY_T_DATA && diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index e8db2b6af2649a..f4ed04921efb24 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -4431,6 +4431,23 @@ impl Function { } } + /// This puts a guard that establishes the precondition for [Self::load_ivar] + fn load_ivar_guard_type(&mut self, block: BlockId, recv: InsnId, recv_type: ProfiledType, state: InsnId) -> InsnId { + if recv_type.class().is_subclass_of(unsafe { rb_cClass }) == ClassRelationship::Subclass { + // Check class first since `Class < Module` + self.push_insn(block, Insn::GuardType { val: recv, guard_type: types::Class, state }) + } else if recv_type.class().is_subclass_of(unsafe { rb_cModule }) == ClassRelationship::Subclass { + self.push_insn(block, Insn::GuardType { val: recv, guard_type: types::Module, state }) + } else if recv_type.flags().is_typed_data() { + self.push_insn(block, Insn::GuardType { val: recv, guard_type: types::TypedTData, state }) + } else { + // HeapBasicObject is wider than T_OBJECT, but shapes for T_OBJECTs are in a pool of + // its own and
are guaranteed to be different from shapes of any other T_* types. So + // the shape check that follows already covers checking for T_OBJECT. + self.push_insn(block, Insn::GuardType { val: recv, guard_type: types::HeapBasicObject, state }) + } + } + fn load_ivar(&mut self, block: BlockId, self_val: InsnId, recv_type: ProfiledType, id: ID, state: InsnId) -> InsnId { // Too-complex shapes use hash tables; rb_shape_get_iv_index doesn't support them. // Callers must filter these out before calling load_ivar. @@ -4500,7 +4517,7 @@ impl Function { self.count(block, Counter::getivar_fallback_too_complex); self.push_insn_id(block, insn_id); continue; } - let self_val = self.push_insn(block, Insn::GuardType { val: self_val, guard_type: types::HeapBasicObject, state }); + let self_val = self.load_ivar_guard_type(block, self_val, recv_type, state); let shape = self.load_shape(block, self_val); self.guard_shape(block, shape, recv_type.shape(), state); let replacement = self.load_ivar(block, self_val, recv_type, id, state); @@ -4529,7 +4546,7 @@ impl Function { self.count(block, Counter::definedivar_fallback_too_complex); self.push_insn_id(block, insn_id); continue; } - let self_val = self.push_insn(block, Insn::GuardType { val: self_val, guard_type: types::HeapBasicObject, state }); + let self_val = self.load_ivar_guard_type(block, self_val, recv_type, state); let shape = self.load_shape(block, self_val); self.guard_shape(block, shape, recv_type.shape(), state); let mut ivar_index: u16 = 0; @@ -4601,7 +4618,7 @@ impl Function { } // Fall through to emitting the ivar write } - let self_val = self.push_insn(block, Insn::GuardType { val: self_val, guard_type: types::HeapBasicObject, state }); + let self_val = self.load_ivar_guard_type(block, self_val, recv_type, state); let shape = self.load_shape(block, self_val); self.guard_shape(block, shape, recv_type.shape(), state); // Current shape contains this ivar diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs index 
3cf9a8f3fa69e8..fde0ab0583d279 100644 --- a/zjit/src/hir/opt_tests.rs +++ b/zjit/src/hir/opt_tests.rs @@ -7346,7 +7346,7 @@ mod hir_opt_tests { Jump bb3(v4) bb3(v6:BasicObject): PatchPoint SingleRactorMode - v17:HeapBasicObject = GuardType v6, HeapBasicObject + v17:Module = GuardType v6, Module v18:CShape = LoadField v17, :_shape_id@0x1000 v19:CShape[0x1001] = GuardBitEquals v18, CShape(0x1001) PatchPoint RootBoxOnly @@ -7381,7 +7381,7 @@ mod hir_opt_tests { Jump bb3(v4) bb3(v6:BasicObject): PatchPoint SingleRactorMode - v17:HeapBasicObject = GuardType v6, HeapBasicObject + v17:Module = GuardType v6, Module v18:CShape = LoadField v17, :_shape_id@0x1000 v19:CShape[0x1001] = GuardBitEquals v18, CShape(0x1001) PatchPoint RootBoxOnly @@ -7414,7 +7414,7 @@ mod hir_opt_tests { Jump bb3(v4) bb3(v6:BasicObject): PatchPoint SingleRactorMode - v17:HeapBasicObject = GuardType v6, HeapBasicObject + v17:Class = GuardType v6, Class v18:CShape = LoadField v17, :_shape_id@0x1000 v19:CShape[0x1001] = GuardBitEquals v18, CShape(0x1001) PatchPoint RootBoxOnly @@ -7449,7 +7449,7 @@ mod hir_opt_tests { Jump bb3(v4) bb3(v6:BasicObject): PatchPoint SingleRactorMode - v17:HeapBasicObject = GuardType v6, HeapBasicObject + v17:Class = GuardType v6, Class v18:CShape = LoadField v17, :_shape_id@0x1000 v19:CShape[0x1001] = GuardBitEquals v18, CShape(0x1001) PatchPoint RootBoxOnly @@ -7520,7 +7520,7 @@ mod hir_opt_tests { Jump bb3(v4) bb3(v6:BasicObject): PatchPoint SingleRactorMode - v17:HeapBasicObject = GuardType v6, HeapBasicObject + v17:TypedTData = GuardType v6, TypedTData v18:CShape = LoadField v17, :_shape_id@0x1000 v19:CShape[0x1001] = GuardBitEquals v18, CShape(0x1001) v20:RubyValue = LoadField v17, :_fields_obj@0x1002 @@ -7556,7 +7556,7 @@ mod hir_opt_tests { Jump bb3(v4) bb3(v6:BasicObject): PatchPoint SingleRactorMode - v17:HeapBasicObject = GuardType v6, HeapBasicObject + v17:TypedTData = GuardType v6, TypedTData v18:CShape = LoadField v17, :_shape_id@0x1000 v19:CShape[0x1001] = 
GuardBitEquals v18, CShape(0x1001) v20:RubyValue = LoadField v17, :_fields_obj@0x1002 diff --git a/zjit/src/hir_type/gen_hir_type.rb b/zjit/src/hir_type/gen_hir_type.rb index 5934dcb2d0ea41..37919425cee51f 100644 --- a/zjit/src/hir_type/gen_hir_type.rb +++ b/zjit/src/hir_type/gen_hir_type.rb @@ -132,6 +132,11 @@ def final_type name, base: $object, c_name: nil true_exact = final_type "TrueClass", c_name: "rb_cTrueClass" false_exact = final_type "FalseClass", c_name: "rb_cFalseClass" +# Typed T_DATA objects (RTYPEDDATA_P). These have a distinct memory layout +# for field access (fields_obj at a fixed offset in RTypedData). These +# don't have a common class ancestor below BasicObject. +basic_object.subtype "TypedTData" + # Build the cvalue object universe. This is for C-level types that may be # passed around when calling into the Ruby VM or after some strength reduction # of HIR. diff --git a/zjit/src/hir_type/hir_type.inc.rs b/zjit/src/hir_type/hir_type.inc.rs index 99697deb82a23b..f37cd57b319445 100644 --- a/zjit/src/hir_type/hir_type.inc.rs +++ b/zjit/src/hir_type/hir_type.inc.rs @@ -4,7 +4,7 @@ mod bits { pub const Array: u64 = ArrayExact | ArraySubclass; pub const ArrayExact: u64 = 1u64 << 0; pub const ArraySubclass: u64 = 1u64 << 1; - pub const BasicObject: u64 = BasicObjectExact | BasicObjectSubclass | Object; + pub const BasicObject: u64 = BasicObjectExact | BasicObjectSubclass | Object | TypedTData; pub const BasicObjectExact: u64 = 1u64 << 2; pub const BasicObjectSubclass: u64 = 1u64 << 3; pub const Bignum: u64 = 1u64 << 4; @@ -73,20 +73,22 @@ mod bits { pub const Symbol: u64 = DynamicSymbol | StaticSymbol; pub const TrueClass: u64 = 1u64 << 43; pub const Truthy: u64 = BasicObject & !Falsy; - pub const Undef: u64 = 1u64 << 44; - pub const AllBitPatterns: [(&str, u64); 74] = [ + pub const TypedTData: u64 = 1u64 << 44; + pub const Undef: u64 = 1u64 << 45; + pub const AllBitPatterns: [(&str, u64); 75] = [ ("Any", Any), ("RubyValue", RubyValue), ("Immediate", 
Immediate), ("Undef", Undef), ("BasicObject", BasicObject), - ("Object", Object), ("NotNil", NotNil), ("Truthy", Truthy), + ("HeapBasicObject", HeapBasicObject), + ("TypedTData", TypedTData), + ("Object", Object), ("BuiltinExact", BuiltinExact), ("BoolExact", BoolExact), ("TrueClass", TrueClass), - ("HeapBasicObject", HeapBasicObject), ("HeapObject", HeapObject), ("String", String), ("Subclass", Subclass), @@ -150,7 +152,7 @@ mod bits { ("ArrayExact", ArrayExact), ("Empty", Empty), ]; - pub const NumTypeBits: u64 = 45; + pub const NumTypeBits: u64 = 46; } pub mod types { use super::*; @@ -227,6 +229,7 @@ pub mod types { pub const Symbol: Type = Type::from_bits(bits::Symbol); pub const TrueClass: Type = Type::from_bits(bits::TrueClass); pub const Truthy: Type = Type::from_bits(bits::Truthy); + pub const TypedTData: Type = Type::from_bits(bits::TypedTData); pub const Undef: Type = Type::from_bits(bits::Undef); pub const ExactBitsAndClass: [(u64, *const VALUE); 17] = [ (bits::ObjectExact, &raw const crate::cruby::rb_cObject), diff --git a/zjit/src/hir_type/mod.rs b/zjit/src/hir_type/mod.rs index 1e6c0d2df7f8b8..0ebdc4be60a17a 100644 --- a/zjit/src/hir_type/mod.rs +++ b/zjit/src/hir_type/mod.rs @@ -1,6 +1,7 @@ //! High-level intermediate representation types. 
#![allow(non_upper_case_globals)] +use crate::cruby; use crate::cruby::{rb_block_param_proxy, Qfalse, Qnil, Qtrue, RUBY_T_ARRAY, RUBY_T_CLASS, RUBY_T_HASH, RUBY_T_MODULE, RUBY_T_STRING, VALUE}; use crate::cruby::{rb_cInteger, rb_cFloat, rb_cArray, rb_cHash, rb_cString, rb_cSymbol, rb_cRange, rb_cModule, rb_zjit_singleton_class_p}; use crate::cruby::ClassRelationship; @@ -213,6 +214,7 @@ impl Type { else if val.class_of() == unsafe { rb_cSymbol } { bits::DynamicSymbol } else if let Some(bits) = Self::bits_from_exact_class(val.class_of()) { bits } else if let Some(bits) = Self::bits_from_subclass(val.class_of()) { bits } + else if val.typed_data_p() { bits::TypedTData } else { unreachable!("Class {} is not a subclass of BasicObject! Don't know what to do.", get_class_name(val.class_of())) @@ -429,6 +431,24 @@ impl Type { matches!(self.spec, Specialization::Object(_)) } + /// Find a `T_*` type that is exactly as wide as `self`. + pub fn builtin_type_equivalent(&self) -> Option { + if self.bit_equal(types::Array) { + Some(cruby::RUBY_T_ARRAY) + } else if self.bit_equal(types::Class) { + Some(cruby::RUBY_T_CLASS) + } else if self.bit_equal(types::Module) { + Some(cruby::RUBY_T_MODULE) + } else if self.bit_equal(types::String) { + Some(cruby::RUBY_T_STRING) + } else if self.bit_equal(types::Hash) { + Some(cruby::RUBY_T_HASH) + } else { + // Note that types::TypedTData is narrower than T_DATA, so not here. + None + } + } + fn is_builtin(class: VALUE) -> bool { types::ExactBitsAndClass .iter() From 6e18d5f5e5bd8e9cd1498123f85b6c83f3f812aa Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 1 Apr 2026 13:08:12 -0400 Subject: [PATCH 3/4] ZJIT: Remove old unused graphviz code (#16630) We use iongraph now. 
--- zjit/src/hir.rs | 232 ------------------------------------------------ 1 file changed, 232 deletions(-) diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index f4ed04921efb24..bc7be7498a3c81 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -2228,22 +2228,6 @@ impl<'a> FunctionPrinter<'a> { } } -/// Pretty printer for [`Function`]. -pub struct FunctionGraphvizPrinter<'a> { - fun: &'a Function, - ptr_map: PtrPrintMap, -} - -impl<'a> FunctionGraphvizPrinter<'a> { - pub fn new(fun: &'a Function) -> Self { - let mut ptr_map = PtrPrintMap::identity(); - if cfg!(test) { - ptr_map.map_ptrs = true; - } - Self { fun, ptr_map } - } -} - /// Union-Find (Disjoint-Set) is a data structure for managing disjoint sets that has an interface /// of two operations: /// @@ -5929,13 +5913,6 @@ impl Function { Some(DumpHIR::Debug) => println!("Optimized HIR:\n{:#?}", &self), None => {}, } - - if let Some(filename) = &get_option!(dump_hir_graphviz) { - use std::fs::OpenOptions; - use std::io::Write; - let mut file = OpenOptions::new().append(true).open(filename).unwrap(); - writeln!(file, "{}", FunctionGraphvizPrinter::new(self)).unwrap(); - } } pub fn dump_iongraph(&self, function_name: &str, passes: Vec) { @@ -6477,87 +6454,6 @@ impl<'a> std::fmt::Display for FunctionPrinter<'a> { } } -struct HtmlEncoder<'a, 'b> { - formatter: &'a mut std::fmt::Formatter<'b>, -} - -impl<'a, 'b> std::fmt::Write for HtmlEncoder<'a, 'b> { - fn write_str(&mut self, s: &str) -> std::fmt::Result { - for ch in s.chars() { - match ch { - '<' => self.formatter.write_str("<")?, - '>' => self.formatter.write_str(">")?, - '&' => self.formatter.write_str("&")?, - '"' => self.formatter.write_str(""")?, - '\'' => self.formatter.write_str("'")?, - _ => self.formatter.write_char(ch)?, - } - } - Ok(()) - } -} - -impl<'a> std::fmt::Display for FunctionGraphvizPrinter<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - macro_rules! 
write_encoded { - ($f:ident, $($arg:tt)*) => { - HtmlEncoder { formatter: $f }.write_fmt(format_args!($($arg)*)) - }; - } - use std::fmt::Write; - let fun = &self.fun; - let iseq_name = iseq_get_location(fun.iseq, 0); - write!(f, "digraph G {{ # ")?; - write_encoded!(f, "{iseq_name}")?; - writeln!(f)?; - writeln!(f, "node [shape=plaintext];")?; - writeln!(f, "mode=hier; overlap=false; splines=true;")?; - for block_id in fun.rpo() { - writeln!(f, r#" {block_id} [label=<"#)?; - write!(f, r#"")?; - for insn_id in &fun.blocks[block_id.0].insns { - let insn_id = fun.union_find.borrow().find_const(*insn_id); - let insn = fun.find(insn_id); - if matches!(insn, Insn::Snapshot {..}) { - continue; - } - write!(f, r#"")?; - } - writeln!(f, "
{block_id}("#)?; - if !fun.blocks[block_id.0].params.is_empty() { - let mut sep = ""; - for param in &fun.blocks[block_id.0].params { - write_encoded!(f, "{sep}{param}")?; - let insn_type = fun.type_of(*param); - if !insn_type.is_subtype(types::Empty) { - write_encoded!(f, ":{}", insn_type.print(&self.ptr_map))?; - } - sep = ", "; - } - } - let mut edges = vec![]; - writeln!(f, ") 
"#)?; - if insn.has_output() { - let insn_type = fun.type_of(insn_id); - if insn_type.is_subtype(types::Empty) { - write_encoded!(f, "{insn_id} = ")?; - } else { - write_encoded!(f, "{insn_id}:{} = ", insn_type.print(&self.ptr_map))?; - } - } - if let Insn::Jump(ref target) | Insn::IfTrue { ref target, .. } | Insn::IfFalse { ref target, .. } = insn { - edges.push((insn_id, target.target)); - } - write_encoded!(f, "{}", insn.print(&self.ptr_map, Some(fun.iseq)))?; - writeln!(f, " 
>];")?; - for (src, dst) in edges { - writeln!(f, " {block_id}:{src} -> {dst}:params:n;")?; - } - } - writeln!(f, "}}") - } -} - #[derive(Debug, Clone, PartialEq)] pub struct FrameState { pub iseq: IseqPtr, @@ -9291,131 +9187,3 @@ mod infer_tests { }); } } - -#[cfg(test)] -mod graphviz_tests { - use super::*; - use insta::assert_snapshot; - - #[track_caller] - fn hir_string(method: &str) -> String { - let iseq = crate::cruby::with_rubyvm(|| get_method_iseq("self", method)); - unsafe { crate::cruby::rb_zjit_profile_disable(iseq) }; - let mut function = iseq_to_hir(iseq).unwrap(); - function.optimize(); - function.validate().unwrap(); - format!("{}", FunctionGraphvizPrinter::new(&function)) - } - - #[test] - fn test_guard_fixnum_or_fixnum() { - eval(r#" - def test(x, y) = x | y - - test(1, 2) - "#); - assert_snapshot!(hir_string("test"), @r#" - digraph G { # test@<compiled>:2 - node [shape=plaintext]; - mode=hier; overlap=false; splines=true; - bb0 [label=< - - -
bb0() 
Entries bb1, bb2 
>]; - bb1 [label=< - - - - - - - -
bb1() 
EntryPoint interpreter 
v1:BasicObject = LoadSelf 
v2:CPtr = LoadSP 
v3:BasicObject = LoadField v2, :x@0x1000 
v4:BasicObject = LoadField v2, :y@0x1001 
Jump bb3(v1, v3, v4) 
>]; - bb1:v5 -> bb3:params:n; - bb2 [label=< - - - - - - -
bb2() 
EntryPoint JIT(0) 
v7:BasicObject = LoadArg :self@0 
v8:BasicObject = LoadArg :x@1 
v9:BasicObject = LoadArg :y@2 
Jump bb3(v7, v8, v9) 
>]; - bb2:v10 -> bb3:params:n; - bb3 [label=< - - - - - - - - -
bb3(v11:BasicObject, v12:BasicObject, v13:BasicObject) 
PatchPoint NoTracePoint 
PatchPoint MethodRedefined(Integer@0x1008, |@0x1010, cme:0x1018) 
v28:Fixnum = GuardType v12, Fixnum 
v29:Fixnum = GuardType v13, Fixnum 
v30:Fixnum = FixnumOr v28, v29 
CheckInterrupts 
Return v30 
>]; - } - "#); - } - - #[test] - fn test_multiple_blocks() { - eval(r#" - def test(c) - if c - 3 - else - 4 - end - end - - test(1) - test("x") - "#); - assert_snapshot!(hir_string("test"), @r#" - digraph G { # test@<compiled>:3 - node [shape=plaintext]; - mode=hier; overlap=false; splines=true; - bb0 [label=< - - -
bb0() 
Entries bb1, bb2 
>]; - bb1 [label=< - - - - - - -
bb1() 
EntryPoint interpreter 
v1:BasicObject = LoadSelf 
v2:CPtr = LoadSP 
v3:BasicObject = LoadField v2, :c@0x1000 
Jump bb3(v1, v3) 
>]; - bb1:v4 -> bb3:params:n; - bb2 [label=< - - - - - -
bb2() 
EntryPoint JIT(0) 
v6:BasicObject = LoadArg :self@0 
v7:BasicObject = LoadArg :c@1 
Jump bb3(v6, v7) 
>]; - bb2:v8 -> bb3:params:n; - bb3 [label=< - - - - - - - - - - - -
bb3(v9:BasicObject, v10:BasicObject) 
PatchPoint NoTracePoint 
CheckInterrupts 
v16:CBool = Test v10 
v17:Falsy = RefineType v10, Falsy 
IfFalse v16, bb4(v9, v17) 
v19:Truthy = RefineType v10, Truthy 
PatchPoint NoTracePoint 
v22:Fixnum[3] = Const Value(3) 
CheckInterrupts 
Return v22 
>]; - bb3:v18 -> bb4:params:n; - bb4 [label=< - - - - - -
bb4(v27:BasicObject, v28:Falsy) 
PatchPoint NoTracePoint 
v32:Fixnum[4] = Const Value(4) 
CheckInterrupts 
Return v32 
>]; - } - "#); - } -} From 1bb1f6c42b380f1c4bc660a400d3204571aaa955 Mon Sep 17 00:00:00 2001 From: Earlopain <14981592+Earlopain@users.noreply.github.com> Date: Wed, 1 Apr 2026 12:57:00 +0200 Subject: [PATCH 4/4] [ruby/prism] Emit `on_kw` for ripper `yard` uses it Start checking against the ordering of the events and also their location. I didn't fix any of the preexisting failures and just ignored them. Some are easy to fix, others look like particularities of ripper that I don't think anyone would rely on. https://github.com/ruby/prism/commit/4cba29d282 --- lib/prism/translation/ripper.rb | 241 +++++++++++++++++++++++++++++++- test/prism/newline_test.rb | 1 + test/prism/ruby/ripper_test.rb | 108 +++++++++++++- 3 files changed, 340 insertions(+), 10 deletions(-) diff --git a/lib/prism/translation/ripper.rb b/lib/prism/translation/ripper.rb index bbfa1f4d05d175..2f66bab97ee5b8 100644 --- a/lib/prism/translation/ripper.rb +++ b/lib/prism/translation/ripper.rb @@ -606,6 +606,9 @@ def parse # alias foo bar # ^^^^^^^^^^^^^ def visit_alias_method_node(node) + bounds(node.keyword_loc) + on_kw("alias") + new_name = visit(node.new_name) old_name = visit(node.old_name) @@ -616,6 +619,9 @@ def visit_alias_method_node(node) # alias $foo $bar # ^^^^^^^^^^^^^^^ def visit_alias_global_variable_node(node) + bounds(node.keyword_loc) + on_kw("alias") + new_name = visit_alias_global_variable_node_value(node.new_name) old_name = visit_alias_global_variable_node_value(node.old_name) @@ -661,6 +667,10 @@ def visit_alternation_pattern_node(node) # ^^^^^^^ def visit_and_node(node) left = visit(node.left) + if node.operator == "and" + bounds(node.operator_loc) + on_kw("and") + end right = visit(node.right) bounds(node.location) @@ -887,8 +897,18 @@ def visit_back_reference_read_node(node) # begin end # ^^^^^^^^^ def visit_begin_node(node) + if node.begin_keyword_loc + bounds(node.begin_keyword_loc) + on_kw("begin") + end + clauses = visit_begin_node_clauses(node.begin_keyword_loc, node, 
false) + if node.end_keyword_loc + bounds(node.end_keyword_loc) + on_kw("end") + end + bounds(node.location) on_begin(clauses) end @@ -909,6 +929,9 @@ def visit_begin_node(node) rescue_clause = visit(node.rescue_clause) else_clause = unless (else_clause_node = node.else_clause).nil? + bounds(else_clause_node.else_keyword_loc) + on_kw("else") + else_statements = if else_clause_node.statements.nil? [nil] @@ -966,6 +989,11 @@ def visit_block_node(node) braces = node.opening == "{" parameters = visit(node.parameters) + unless braces + bounds(node.opening_loc) + on_kw("do") + end + body = case node.body when nil @@ -987,6 +1015,11 @@ def visit_block_node(node) raise end + unless braces + bounds(node.closing_loc) + on_kw("end") + end + if braces bounds(node.location) on_brace_block(parameters, body) @@ -1037,6 +1070,9 @@ def visit_block_parameters_node(node) # break foo # ^^^^^^^^^ def visit_break_node(node) + bounds(node.keyword_loc) + on_kw("break") + if node.arguments.nil? bounds(node.location) on_break(on_args_new) @@ -1103,6 +1139,9 @@ def visit_call_node(node) on_unary(node.name, receiver) when :! if node.message == "not" + bounds(node.message_loc) + on_kw("not") + receiver = if !node.receiver.is_a?(ParenthesesNode) || !node.receiver.body.nil? 
visit(node.receiver) @@ -1347,10 +1386,21 @@ def visit_capture_pattern_node(node) # case foo; when bar; end # ^^^^^^^^^^^^^^^^^^^^^^^ def visit_case_node(node) + bounds(node.case_keyword_loc) + on_kw("case") + predicate = visit(node.predicate) + visited_conditions = node.conditions.map { |condition| visit(condition) } + visited_else_clause = visit(node.else_clause) + + if !node.else_clause + bounds(node.end_keyword_loc) + on_kw("end") + end + clauses = - node.conditions.reverse_each.inject(visit(node.else_clause)) do |current, condition| - on_when(*visit(condition), current) + visited_conditions.reverse_each.inject(visited_else_clause) do |current, condition| + on_when(*condition, current) end bounds(node.location) @@ -1360,10 +1410,23 @@ def visit_case_node(node) # case foo; in bar; end # ^^^^^^^^^^^^^^^^^^^^^ def visit_case_match_node(node) + bounds(node.case_keyword_loc) + on_kw("case") + predicate = visit(node.predicate) + visited_conditions = node.conditions.map do | condition| + visit(condition) + end + visited_else_clause = visit(node.else_clause) + + if !node.else_clause + bounds(node.end_keyword_loc) + on_kw("end") + end + clauses = - node.conditions.reverse_each.inject(visit(node.else_clause)) do |current, condition| - on_in(*visit(condition), current) + visited_conditions.reverse_each.inject(visited_else_clause) do |current, condition| + on_in(*condition, current) end bounds(node.location) @@ -1373,6 +1436,9 @@ def visit_case_match_node(node) # class Foo; end # ^^^^^^^^^^^^^^ def visit_class_node(node) + bounds(node.class_keyword_loc) + on_kw("class") + constant_path = if node.constant_path.is_a?(ConstantReadNode) bounds(node.constant_path.location) @@ -1384,6 +1450,9 @@ def visit_class_node(node) superclass = visit(node.superclass) bodystmt = visit_body_node(node.superclass&.location || node.constant_path.location, node.body, node.superclass.nil?) 
+ bounds(node.end_keyword_loc) + on_kw("end") + bounds(node.location) on_class(constant_path, superclass, bodystmt) end @@ -1631,6 +1700,9 @@ def visit_constant_path_target_node(node) # def self.foo; end # ^^^^^^^^^^^^^^^^^ def visit_def_node(node) + bounds(node.def_keyword_loc) + on_kw("def") + receiver = visit(node.receiver) operator = if !node.operator_loc.nil? @@ -1664,6 +1736,11 @@ def visit_def_node(node) on_bodystmt(body, nil, nil, nil) end + if node.end_keyword_loc + bounds(node.end_keyword_loc) + on_kw("end") + end + bounds(node.location) if receiver on_defs(receiver, operator, name, parameters, bodystmt) @@ -1678,6 +1755,9 @@ def visit_def_node(node) # defined?(a) # ^^^^^^^^^^^ def visit_defined_node(node) + bounds(node.keyword_loc) + on_kw("defined?") + expression = visit(node.value) # Very weird circumstances here where something like: @@ -1700,6 +1780,9 @@ def visit_defined_node(node) # if foo then bar else baz end # ^^^^^^^^^^^^ def visit_else_node(node) + bounds(node.else_keyword_loc) + on_kw("else") + statements = if node.statements.nil? [nil] @@ -1709,8 +1792,12 @@ def visit_else_node(node) body end + else_statements = visit_statements_node_body(statements) + + bounds(node.end_keyword_loc) + on_kw("end") bounds(node.location) - on_else(visit_statements_node_body(statements)) + on_else(else_statements) end # "foo #{bar}" @@ -1748,6 +1835,9 @@ def visit_embedded_variable_node(node) # Visit an EnsureNode node. def visit_ensure_node(node) + bounds(node.ensure_keyword_loc) + on_kw("ensure") + statements = if node.statements.nil? [nil] @@ -1818,8 +1908,18 @@ def visit_float_node(node) # for foo in bar do end # ^^^^^^^^^^^^^^^^^^^^^ def visit_for_node(node) + bounds(node.for_keyword_loc) + on_kw("for") + index = visit(node.index) + bounds(node.in_keyword_loc) + on_kw("in") + collection = visit(node.collection) + if node.do_keyword_loc + bounds(node.do_keyword_loc) + on_kw("do") + end statements = if node.statements.nil? 
bounds(node.location) @@ -1828,6 +1928,9 @@ def visit_for_node(node) visit(node.statements) end + bounds(node.end_keyword_loc) + on_kw("end") + bounds(node.location) on_for(index, collection, statements) end @@ -1852,6 +1955,9 @@ def visit_forwarding_parameter_node(node) # super {} # ^^^^^^^^ def visit_forwarding_super_node(node) + bounds(node.keyword_loc) + on_kw("super") + if node.block.nil? bounds(node.location) on_zsuper @@ -2001,7 +2107,13 @@ def visit_if_node(node) bounds(node.location) on_ifop(predicate, truthy, falsy) elsif node.statements.nil? || (node.predicate.location.start_offset < node.statements.location.start_offset) + bounds(node.if_keyword_loc) + on_kw(node.if_keyword) predicate = visit(node.predicate) + if node.then_keyword_loc && node.then_keyword != "?" + bounds(node.then_keyword_loc) + on_kw("then") + end statements = if node.statements.nil? bounds(node.location) @@ -2011,6 +2123,11 @@ def visit_if_node(node) end subsequent = visit(node.subsequent) + if node.end_keyword_loc && !node.subsequent + bounds(node.end_keyword_loc) + on_kw("end") + end + bounds(node.location) if node.if_keyword == "if" on_if(predicate, statements, subsequent) @@ -2019,6 +2136,8 @@ def visit_if_node(node) end else statements = visit(node.statements.body.first) + bounds(node.if_keyword_loc) + on_kw(node.if_keyword) predicate = visit(node.predicate) bounds(node.location) @@ -2050,7 +2169,14 @@ def visit_in_node(node) # This is a special case where we're not going to call on_in directly # because we don't have access to the subsequent. Instead, we'll return # the component parts and let the parent node handle it. + bounds(node.in_loc) + on_kw("in") + pattern = visit_pattern_node(node.pattern) + if node.then_loc + bounds(node.then_loc) + on_kw("then") + end statements = if node.statements.nil? 
bounds(node.location) @@ -2386,6 +2512,11 @@ def visit_lambda_node(node) on_tlambeg(node.opening) end + unless braces + bounds(node.opening_loc) + on_kw("do") + end + body = case node.body when nil @@ -2407,6 +2538,11 @@ def visit_lambda_node(node) raise end + unless braces + bounds(node.closing_loc) + on_kw("end") + end + bounds(node.location) on_lambda(parameters, body) end @@ -2497,6 +2633,8 @@ def visit_match_last_line_node(node) # ^^^^^^^^^^ def visit_match_predicate_node(node) value = visit(node.value) + bounds(node.operator_loc) + on_kw("in") pattern = on_in(visit_pattern_node(node.pattern), nil, nil) on_case(value, pattern) @@ -2526,6 +2664,9 @@ def visit_error_recovery_node(node) # module Foo; end # ^^^^^^^^^^^^^^^ def visit_module_node(node) + bounds(node.module_keyword_loc) + on_kw("module") + constant_path = if node.constant_path.is_a?(ConstantReadNode) bounds(node.constant_path.location) @@ -2536,6 +2677,9 @@ def visit_module_node(node) bodystmt = visit_body_node(node.constant_path.location, node.body, true) + bounds(node.end_keyword_loc) + on_kw("end") + bounds(node.location) on_module(constant_path, bodystmt) end @@ -2617,6 +2761,9 @@ def visit_multi_write_node(node) # next foo # ^^^^^^^^ def visit_next_node(node) + bounds(node.keyword_loc) + on_kw("next") + if node.arguments.nil? 
bounds(node.location) on_next(on_args_new) @@ -2638,6 +2785,8 @@ def visit_nil_node(node) # def foo(&nil); end # ^^^^ def visit_no_block_parameter_node(node) + bounds(node.keyword_loc) + on_kw("nil") bounds(node.location) on_blockarg(:nil) end @@ -2645,6 +2794,8 @@ def visit_no_block_parameter_node(node) # def foo(**nil); end # ^^^^^ def visit_no_keywords_parameter_node(node) + bounds(node.keyword_loc) + on_kw("nil") bounds(node.location) on_nokw_param(nil) @@ -2687,6 +2838,10 @@ def visit_optional_parameter_node(node) # ^^^^^^ def visit_or_node(node) left = visit(node.left) + if node.operator == "or" + bounds(node.operator_loc) + on_kw("or") + end right = visit(node.right) bounds(node.location) @@ -2752,6 +2907,9 @@ def visit_pinned_variable_node(node) # END {} # ^^^^^^ def visit_post_execution_node(node) + bounds(node.keyword_loc) + on_kw("END") + statements = if node.statements.nil? bounds(node.location) @@ -2767,6 +2925,9 @@ def visit_post_execution_node(node) # BEGIN {} # ^^^^^^^^ def visit_pre_execution_node(node) + bounds(node.keyword_loc) + on_kw("BEGIN") + statements = if node.statements.nil? 
bounds(node.location) @@ -2813,6 +2974,7 @@ def visit_rational_node(node) # ^^^^ def visit_redo_node(node) bounds(node.location) + on_kw("redo") on_redo end @@ -2855,6 +3017,9 @@ def visit_required_parameter_node(node) # foo rescue bar # ^^^^^^^^^^^^^^ def visit_rescue_modifier_node(node) + bounds(node.keyword_loc) + on_kw("rescue") + expression = visit_write_value(node.expression) rescue_expression = visit(node.rescue_expression) @@ -2865,6 +3030,9 @@ def visit_rescue_modifier_node(node) # begin; rescue; end # ^^^^^^^ def visit_rescue_node(node) + bounds(node.keyword_loc) + on_kw("rescue") + exceptions = case node.exceptions.length when 0 @@ -2936,6 +3104,7 @@ def visit_rest_parameter_node(node) # ^^^^^ def visit_retry_node(node) bounds(node.location) + on_kw("retry") on_retry end @@ -2945,6 +3114,9 @@ def visit_retry_node(node) # return 1 # ^^^^^^^^ def visit_return_node(node) + bounds(node.keyword_loc) + on_kw("return") + if node.arguments.nil? bounds(node.location) on_return0 @@ -2971,9 +3143,15 @@ def visit_shareable_constant_node(node) # class << self; end # ^^^^^^^^^^^^^^^^^^ def visit_singleton_class_node(node) + bounds(node.class_keyword_loc) + on_kw("class") + expression = visit(node.expression) bodystmt = visit_body_node(node.body&.location || node.end_keyword_loc, node.body) + bounds(node.end_keyword_loc) + on_kw("end") + bounds(node.location) on_sclass(expression, bodystmt) end @@ -3180,6 +3358,9 @@ def visit_string_node(node) # super(foo) # ^^^^^^^^^^ def visit_super_node(node) + bounds(node.keyword_loc) + on_kw("super") + arguments, block, has_ripper_block = visit_call_node_arguments(node.arguments, node.block, trailing_comma?(node.arguments&.location || node.location, node.rparen_loc || node.location)) if !node.lparen_loc.nil? 
@@ -3233,6 +3414,9 @@ def visit_true_node(node) # undef foo # ^^^^^^^^^ def visit_undef_node(node) + bounds(node.keyword_loc) + on_kw("undef") + names = visit_all(node.names) bounds(node.location) @@ -3246,7 +3430,13 @@ def visit_undef_node(node) # ^^^^^^^^^^^^^^ def visit_unless_node(node) if node.statements.nil? || (node.predicate.location.start_offset < node.statements.location.start_offset) + bounds(node.keyword_loc) + on_kw("unless") predicate = visit(node.predicate) + if node.then_keyword_loc + bounds(node.then_keyword_loc) + on_kw("then") + end statements = if node.statements.nil? bounds(node.location) @@ -3256,10 +3446,17 @@ def visit_unless_node(node) end else_clause = visit(node.else_clause) + if node.end_keyword_loc && !node.else_clause + bounds(node.end_keyword_loc) + on_kw("end") + end + bounds(node.location) on_unless(predicate, statements, else_clause) else statements = visit(node.statements.body.first) + bounds(node.keyword_loc) + on_kw("unless") predicate = visit(node.predicate) bounds(node.location) @@ -3273,7 +3470,14 @@ def visit_unless_node(node) # bar until foo # ^^^^^^^^^^^^^ def visit_until_node(node) + bounds(node.keyword_loc) + on_kw("until") + if node.statements.nil? || (node.predicate.location.start_offset < node.statements.location.start_offset) + if node.do_keyword_loc + bounds(node.do_keyword_loc) + on_kw("do") + end predicate = visit(node.predicate) statements = if node.statements.nil? @@ -3283,6 +3487,11 @@ def visit_until_node(node) visit(node.statements) end + if node.closing_loc + bounds(node.closing_loc) + on_kw("end") + end + bounds(node.location) on_until(predicate, statements) else @@ -3300,7 +3509,14 @@ def visit_when_node(node) # This is a special case where we're not going to call on_when directly # because we don't have access to the subsequent. Instead, we'll return # the component parts and let the parent node handle it. 
+ bounds(node.keyword_loc) + on_kw("when") + conditions = visit_arguments(node.conditions) + if node.then_keyword_loc + bounds(node.then_keyword_loc) + on_kw("then") + end statements = if node.statements.nil? bounds(node.location) @@ -3319,7 +3535,17 @@ def visit_when_node(node) # ^^^^^^^^^^^^^ def visit_while_node(node) if node.statements.nil? || (node.predicate.location.start_offset < node.statements.location.start_offset) + bounds(node.keyword_loc) + on_kw("while") + if node.do_keyword_loc + bounds(node.do_keyword_loc) + on_kw("do") + end predicate = visit(node.predicate) + if node.closing_loc + bounds(node.closing_loc) + on_kw("end") + end statements = if node.statements.nil? bounds(node.location) @@ -3332,6 +3558,8 @@ def visit_while_node(node) on_while(predicate, statements) else statements = visit(node.statements.body.first) + bounds(node.keyword_loc) + on_kw("while") predicate = visit(node.predicate) bounds(node.location) @@ -3367,6 +3595,9 @@ def visit_x_string_node(node) # yield 1 # ^^^^^^^ def visit_yield_node(node) + bounds(node.keyword_loc) + on_kw("yield") + if node.arguments.nil? && node.lparen_loc.nil? 
bounds(node.location) on_yield0 diff --git a/test/prism/newline_test.rb b/test/prism/newline_test.rb index c8914b57dcfac8..97e698202df748 100644 --- a/test/prism/newline_test.rb +++ b/test/prism/newline_test.rb @@ -20,6 +20,7 @@ class NewlineTest < TestCase ruby/find_fixtures.rb ruby/find_test.rb ruby/parser_test.rb + ruby/ripper_test.rb ruby/ruby_parser_test.rb ] diff --git a/test/prism/ruby/ripper_test.rb b/test/prism/ruby/ripper_test.rb index 7274454e1b44ed..1d20bceb40d6dc 100644 --- a/test/prism/ruby/ripper_test.rb +++ b/test/prism/ruby/ripper_test.rb @@ -136,8 +136,93 @@ def test_lex_ignored_missing_heredoc_end end end - UNSUPPORTED_EVENTS = %i[comma ignored_nl kw label_end lbrace lbracket lparen nl op rbrace rbracket rparen semicolon sp words_sep ignored_sp] + # Events that are currently not emitted + UNSUPPORTED_EVENTS = %i[comma ignored_nl label_end lbrace lbracket lparen nl op rbrace rbracket rparen semicolon sp words_sep ignored_sp] SUPPORTED_EVENTS = Translation::Ripper::EVENTS - UNSUPPORTED_EVENTS + # Events that assert against their line/column + CHECK_LOCATION_EVENTS = %i[kw] + IGNORE_FOR_SORT_EVENTS = %i[ + stmts_new stmts_add bodystmt void_stmt + args_new args_add args_add_star args_add_block arg_paren method_add_arg + mlhs_new mlhs_add_star + word_new words_new symbols_new qwords_new qsymbols_new xstring_new regexp_new + words_add symbols_add qwords_add qsymbols_add + regexp_end tstring_end heredoc_end + call command fcall vcall + field aref_field var_field var_ref block_var ident params + string_content heredoc_dedent unary binary dyna_symbol + comment magic_comment embdoc embdoc_beg embdoc_end arg_ambiguous + ] + SORT_IGNORE = { + aref: [ + "blocks.txt", + "command_method_call.txt", + "whitequark/ruby_bug_13547.txt", + ], + assoc_new: [ + "case_in_hash_key.txt", + "whitequark/parser_bug_525.txt", + "whitequark/ruby_bug_11380.txt", + ], + bare_assoc_hash: [ + "case_in_hash_key.txt", + "method_calls.txt", + "whitequark/parser_bug_525.txt", + 
"whitequark/ruby_bug_11380.txt", + ], + brace_block: [ + "super.txt", + "unparser/corpus/literal/super.txt" + ], + command_call: [ + "blocks.txt", + "case_in_hash_key.txt", + "seattlerb/block_call_dot_op2_cmd_args_do_block.txt", + "seattlerb/block_call_operation_colon.txt", + "seattlerb/block_call_operation_dot.txt", + ], + const_path_field: [ + "seattlerb/const_2_op_asgn_or2.txt", + "seattlerb/const_op_asgn_or.txt", + "whitequark/const_op_asgn.txt", + ], + const_path_ref: ["unparser/corpus/literal/defs.txt"], + do_block: ["whitequark/super_block.txt"], + embexpr_end: ["seattlerb/str_interp_ternary_or_label.txt"], + rest_param: ["whitequark/send_lambda.txt"], + top_const_field: [ + "seattlerb/const_3_op_asgn_or.txt", + "seattlerb/const_op_asgn_and1.txt", + "seattlerb/const_op_asgn_and2.txt", + "whitequark/const_op_asgn.txt", + ], + mlhs_paren: ["unparser/corpus/literal/for.txt"], + mlhs_add: [ + "whitequark/for_mlhs.txt", + ], + kw: [ + "defined.txt", + "for.txt", + "seattlerb/block_kw__required.txt", + "seattlerb/case_in_42.txt", + "seattlerb/case_in_67.txt", + "seattlerb/case_in_86_2.txt", + "seattlerb/case_in_86.txt", + "seattlerb/case_in_hash_pat_paren_true.txt", + "seattlerb/flip2_env_lvar.txt", + "unless.txt", + "unparser/corpus/semantic/and.txt", + "whitequark/class.txt", + "whitequark/find_pattern.txt", + "whitequark/pattern_matching_hash.txt", + "whitequark/pattern_matching_implicit_array_match.txt", + "whitequark/pattern_matching_ranges.txt", + "whitequark/super_block.txt", + "write_command_operator.txt", + ], + } + SORT_IGNORE.default = [] + SORT_EVENTS = SUPPORTED_EVENTS - IGNORE_FOR_SORT_EVENTS module Events attr_reader :events @@ -147,9 +232,20 @@ def initialize(...) @events = [] end + def sorted_events + @events.select do |e,| + next false if e == :kw && @events.any? 
{ |e,| e == :if_mod || e == :while_mod || e == :until_mod || e == :rescue || e == :rescue_mod || e == :while || e == :ensure } + SORT_EVENTS.include?(e) && !SORT_IGNORE[e].include?(filename) + end + end + SUPPORTED_EVENTS.each do |event| define_method(:"on_#{event}") do |*args| - @events << [event, *args.map(&:to_s)] + if CHECK_LOCATION_EVENTS.include?(event) + @events << [event, lineno, column, *args.map(&:to_s)] + else + @events << [event, *args.map(&:to_s)] + end super(*args) end end @@ -177,12 +273,14 @@ class ObjectEvents < Translation::Ripper object_events = ObjectEvents.new(source) assert_nothing_raised { object_events.parse } - ripper = RipperEvents.new(source) - prism = PrismEvents.new(source) + ripper = RipperEvents.new(source, fixture.path) + prism = PrismEvents.new(source, fixture.path) ripper.parse prism.parse - # This makes sure that the content is the same. Ordering is not correct for now. + # Check that the same events are emitted, regardless of order assert_equal(ripper.events.sort, prism.events.sort) + # Check a subset of events against the correct order + assert_equal(ripper.sorted_events, prism.sorted_events) end end