Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/struct_ops_simple.ks
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ impl minimal_congestion_control {
return 16
}

fn undo_cwnd(sk: *u8) -> u32 {
return ssthresh(sk)
}

fn cong_avoid(sk: *u8, ack: u32, acked: u32) -> void {
// Minimal TCP congestion avoidance implementation
// In a real implementation, this would adjust the congestion window
Expand Down
165 changes: 108 additions & 57 deletions src/ir_generator.ml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ type ir_context = {
map_origin_variables: (string, (string * ir_value * (ir_value_desc * ir_type))) Hashtbl.t; (* var_name -> (map_name, key, underlying_info) *)
(* Track inferred variable types for proper lookups *)
variable_types: (string, ir_type) Hashtbl.t; (* var_name -> ir_type *)
mutable current_program_type: program_type option;
}

(** Create new IR generation context *)
Expand All @@ -91,6 +92,7 @@ let create_context ?(global_variables = []) ?(helper_functions = []) symbol_tabl
tbl);
map_origin_variables = Hashtbl.create 32;
variable_types = Hashtbl.create 32;
current_program_type = None;
helper_functions = (let tbl = Hashtbl.create 16 in
List.iter (fun helper_name -> Hashtbl.add tbl helper_name ()) helper_functions;
tbl);
Expand Down Expand Up @@ -349,6 +351,85 @@ let extract_struct_ops_kernel_name attributes =
| _ -> acc
) "" attributes

let ast_struct_has_field ast struct_name field_name =
List.exists (function
| Ast.StructDecl struct_def when struct_def.Ast.struct_name = struct_name ->
List.exists (fun (name, _) -> name = field_name) struct_def.Ast.struct_fields
| _ -> false
) ast

let impl_block_has_static_field impl_block field_name =
List.exists (function
| Ast.ImplStaticField (name, _) when name = field_name -> true
| _ -> false
) impl_block.Ast.impl_items

let normalize_struct_ops_instance_name name =
let buffer = Buffer.create (String.length name * 2) in
let is_uppercase ch = ch >= 'A' && ch <= 'Z' in
let is_lowercase ch = ch >= 'a' && ch <= 'z' in
let is_digit ch = ch >= '0' && ch <= '9' in
let add_separator_if_needed idx ch =
if idx > 0 && is_uppercase ch then
let prev = name.[idx - 1] in
let next_is_lowercase = idx + 1 < String.length name && is_lowercase name.[idx + 1] in
if is_lowercase prev || is_digit prev || (is_uppercase prev && next_is_lowercase) then
Buffer.add_char buffer '_'
in
String.iteri (fun idx ch ->
add_separator_if_needed idx ch;
let normalized =
if is_uppercase ch then Char.lowercase_ascii ch
else if is_lowercase ch || is_digit ch || ch = '_' then ch
else '_'
in
Buffer.add_char buffer normalized
) name;
Buffer.contents buffer

let generate_default_struct_ops_name instance_name =
let max_len = 15 in
let normalized = normalize_struct_ops_instance_name instance_name in
if String.length normalized <= max_len then normalized
else
let parts = List.filter (fun part -> part <> "") (String.split_on_char '_' normalized) in
match parts with
| [] -> String.sub normalized 0 max_len
| first :: rest ->
let abbreviated =
match rest with
| [] -> first
| _ ->
let initials = rest |> List.map (fun part -> String.make 1 part.[0]) |> String.concat "" in
first ^ "_" ^ initials
in
if String.length abbreviated <= max_len then abbreviated
else String.sub abbreviated 0 max_len

let should_lower_as_implicit_tail_call ctx name =
let is_function_pointer =
Hashtbl.mem ctx.function_parameters name ||
match Hashtbl.find_opt ctx.variable_types name with
| Some (IRFunctionPointer _) -> true
| _ -> false
in
if is_function_pointer || Hashtbl.mem ctx.helper_functions name then
false
else
match ctx.current_function, ctx.current_program_type with
| Some _, Some Ast.StructOps -> false
| Some current_func_name, Some _ ->
let caller_is_attributed =
try Symbol_table.lookup_function ctx.symbol_table current_func_name <> None
with _ -> false
in
let target_is_attributed =
try Symbol_table.lookup_function ctx.symbol_table name <> None
with _ -> false
in
caller_is_attributed && target_is_attributed
| _ -> false


(** Map struct names to their corresponding context types *)
let struct_name_to_context_type = function
Expand Down Expand Up @@ -1659,14 +1740,12 @@ and lower_statement ctx stmt =
(* Check if this is a simple function call that could be a tail call *)
(match callee_expr.expr_desc with
| Ast.Identifier name ->
(* Check if this is a helper function - if so, treat as regular call *)
if Hashtbl.mem ctx.helper_functions name then
let ret_val = lower_expression ctx expr in
IRReturnValue ret_val
else
(* This will be converted to tail call by tail call analyzer *)
if should_lower_as_implicit_tail_call ctx name then
let arg_vals = List.map (lower_expression ctx) args in
IRReturnCall (name, arg_vals)
else
let ret_val = lower_expression ctx expr in
IRReturnValue ret_val
| _ ->
(* Function pointer call - treat as regular return *)
let ret_val = lower_expression ctx expr in
Expand All @@ -1689,13 +1768,12 @@ and lower_statement ctx stmt =
(* Check if this is a simple function call that could be a tail call *)
(match callee_expr.expr_desc with
| Ast.Identifier name ->
(* Check if this is a helper function - if so, treat as regular call *)
if Hashtbl.mem ctx.helper_functions name then
let ret_val = lower_expression ctx return_expr in
IRReturnValue ret_val
else
if should_lower_as_implicit_tail_call ctx name then
let arg_vals = List.map (lower_expression ctx) args in
IRReturnCall (name, arg_vals)
else
let ret_val = lower_expression ctx return_expr in
IRReturnValue ret_val
| _ ->
(* Function pointer call - treat as regular return *)
let ret_val = lower_expression ctx return_expr in
Expand All @@ -1712,13 +1790,12 @@ and lower_statement ctx stmt =
| Ast.Call (callee_expr, args) ->
(match callee_expr.expr_desc with
| Ast.Identifier name ->
(* Check if this is a helper function - if so, treat as regular call *)
if Hashtbl.mem ctx.helper_functions name then
let ret_val = lower_expression ctx expr in
IRReturnValue ret_val
else
if should_lower_as_implicit_tail_call ctx name then
let arg_vals = List.map (lower_expression ctx) args in
IRReturnCall (name, arg_vals)
else
let ret_val = lower_expression ctx expr in
IRReturnValue ret_val
| _ ->
let ret_val = lower_expression ctx expr in
IRReturnValue ret_val)
Expand Down Expand Up @@ -1761,47 +1838,7 @@ and lower_statement ctx stmt =
(* Check if this is a simple function call that could be a tail call *)
(match callee_expr.expr_desc with
| Ast.Identifier name ->
(* Check if this should be a tail call *)
let should_be_tail_call =
(* First check if the identifier is a function parameter or variable (function pointer) *)
let is_function_pointer =
Hashtbl.mem ctx.function_parameters name ||
Hashtbl.mem ctx.variable_types name
in

if is_function_pointer then
(* Function pointer calls should never be tail calls *)
false
else
(* Check if we're in an attributed function context *)
match ctx.current_function with
| Some current_func_name ->
(* Check if caller is attributed (has eBPF attributes) *)
let caller_is_attributed =
try
let caller_symbol = Symbol_table.lookup_function ctx.symbol_table current_func_name in
(* TODO: Check if caller has eBPF attributes like @xdp, @tc, etc. *)
(* For now, assume attributed functions are defined in symbol table *)
caller_symbol <> None
with _ -> false
in

(* Check if target function is an attributed function *)
let target_is_attributed =
try
let target_symbol = Symbol_table.lookup_function ctx.symbol_table name in
(* TODO: Check if target has eBPF attributes like @xdp, @tc, etc. *)
(* For now, assume attributed functions are defined in symbol table *)
target_symbol <> None
with _ -> false
in

(* Only allow tail calls between attributed functions *)
caller_is_attributed && target_is_attributed
| None -> false
in

if should_be_tail_call then
if should_lower_as_implicit_tail_call ctx name then
(* Generate tail call instruction *)
let arg_vals = List.map (lower_expression ctx) args in
let tail_call_index = 0 in (* This will be set by tail call analyzer *)
Expand Down Expand Up @@ -2356,6 +2393,7 @@ let convert_match_return_calls_to_tail_calls ir_function =
(** Lower AST function to IR function *)
let lower_function ctx prog_name ?(program_type : program_type option = None) ?(func_target = None) (func_def : Ast.function_def) =
ctx.current_function <- Some func_def.func_name;
ctx.current_program_type <- program_type;

(* Reset for new function *)
Hashtbl.clear ctx.variable_types;
Expand Down Expand Up @@ -3125,6 +3163,19 @@ let lower_multi_program ast symbol_table source_name =
in
Some (field_name, field_val)
) impl_block.impl_items in
let ir_instance_fields =
if ast_struct_has_field ast kernel_struct_name "name" && not (impl_block_has_static_field impl_block "name") then
let generated_name = generate_default_struct_ops_name impl_block.impl_name in
let generated_name_val =
make_ir_value
(IRLiteral (StringLit generated_name))
(IRStr (String.length generated_name + 1))
impl_block.impl_pos
in
ir_instance_fields @ [("name", generated_name_val)]
else
ir_instance_fields
in
let ir_instance = make_ir_struct_ops_instance
impl_block.impl_name
kernel_struct_name
Expand Down
Loading