Skip to content

Commit 45a4d78

Browse files
authored
Operator overloading (#753)
This patch allows you to implement operator overloading in OSL -- for example, if you define a type like struct vector4 { float x, y, z, w; }; you can also define math operators in the following manner: vector4 __operator__add__ (vector4 a, vector4 b) { return vector4 (a.x+b.x a.y+b.y, a.z+b.z, a.w+b.w); } and then use plain old `+` as you would for any built-in types: vector4 a, b, c; c = a + b; // automatically calls __operator__add__(vector4,vector4) It's fairly straightforward: for any of the usual unary and binary operators, if a function called __operator__NAME__ is visible from the current scope and the types match, it is understood that the operator really means a function call to that special function. The valid NAME choices are add, sub, mul, div, neg, and so on. See the docs for full list, but it's pretty obvious. With great power comes great responsibility! Use this wisely -- that is, sparingly, and only when the operator is exactly analogous to how it works with the built-in types. Same general guidance about using it in C++.
1 parent 01615ad commit 45a4d78

9 files changed

Lines changed: 314 additions & 74 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ TESTSUITE ( and-or-not-synonyms aastep arithmetic array array-derivs array-range
252252
noise-gabor noise-gabor2d-filter noise-gabor3d-filter
253253
noise-perlin noise-uperlin noise-simplex noise-usimplex
254254
pnoise pnoise-cell pnoise-gabor pnoise-perlin pnoise-uperlin
255+
operator-overloading
255256
oslc-comma oslc-D
256257
oslc-err-arrayindex oslc-err-closuremul
257258
oslc-err-format oslc-err-intoverflow

src/doc/languagespec.tex

Lines changed: 111 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@
6161
Editor: Larry Gritz \\
6262
\emph{lg@imageworks.com}
6363
}
64-
\date{{\large Date: 6 Feb 2017 \\
65-
(with corrections, 12 May 2017)
64+
\date{{\large Date: 16 May 2017 \\
65+
% (with corrections, 12 May 2017)
6666
}
6767
\bigskip
6868
\bigskip
@@ -886,7 +886,7 @@ \section{Preprocessor}
886886
\end{tabular}
887887

888888

889-
\chapter{Gross syntax, shader types, parameters, functions}
889+
\chapter{Gross syntax, shader types, parameters}
890890
\label{chap:grosssyntax}
891891

892892
The overall structure of a shader is as follows:
@@ -1336,48 +1336,6 @@ \section{Shader metadata}
13361336

13371337

13381338
\newpage
1339-
\section{Functions}
1340-
\label{sec:functions}
1341-
\index{functions!declarations}
1342-
1343-
You may define functions much like in C or C++.
1344-
1345-
\begin{quote}
1346-
\em
1347-
return-type ~ function-name ~ {\rm\cf (} optional-parameters {\rm\cf )} \\
1348-
\rm
1349-
{\cf \{ } \\
1350-
\em
1351-
\spc statements
1352-
1353-
{\cf \} }
1354-
\end{quote}
1355-
1356-
Parameters to functions are similar to shader parameters, except that
1357-
they do not permit initializers. A function call must pass values for
1358-
all formal parameters. Function parameters in \langname are all
1359-
\emph{passed by reference}, and are read-only within the body of the
1360-
function unless they are also designated as {\cf output} (in the same
1361-
manner as output shader parameters).
1362-
1363-
Like for shaders, statements inside functions may be actual executions
1364-
(assignments, function call, etc.), local variable declarations (visible
1365-
only from within the body of the function), or local function
1366-
declarations (callable only from within the body of the function).
1367-
1368-
The return type may be any simple data type, a {\cf struct}, or a {\cf
1369-
closure}. Functions may not return arrays. The return type may be
1370-
{\cf void}, indicating that the function does not return a value (and
1371-
should not contain a {\cf return} statement). A {\cf return} statement
1372-
inside the body of the function will halt execution of the function at
1373-
that point, and designates the value that will be returned (if not a
1374-
{\cf void} function).
1375-
1376-
Functions may be \emph{overloaded}. That is, multiple functions may be
1377-
defined to have the same name, as long as they have differently-typed
1378-
parameters, so that when the function is called the list of arguments
1379-
can disambiguate which version of the function is desired.
1380-
13811339
\section{Public methods}
13821340
\label{sec:publicmethods}
13831341
\indexapi{public}
@@ -2632,7 +2590,10 @@ \section{Control flow: {\cf if, while, do, for}}
26322590
and proceeds to the next iteration of the loop.
26332591

26342592
\section{Functions}
2635-
\label{sec:syntax:functions}
2593+
\label{sec:functions}
2594+
2595+
\subsection{Function calls}
2596+
\label{sec:syntax:functioncalls}
26362597
\index{function calls}
26372598

26382599
Function calls are very similar to C and related programming languages:
@@ -2651,8 +2612,111 @@ \section{Functions}
26512612
semantics, except if you pass the same variable as two separate
26522613
arguments to a function that modifies an argument's value.
26532614

2654-
Function definitions are described in detail in Section~\ref{sec:functions}.
26552615

2616+
\subsection{Function definitions}
2617+
\label{sec:syntax:functiondefinitions}
2618+
\index{functions!definitions}
2619+
2620+
You may define functions much like in C or C++.
2621+
2622+
\begin{quote}
2623+
\em
2624+
return-type ~ function-name ~ {\rm\cf (} optional-parameters {\rm\cf )} \\
2625+
\rm
2626+
{\cf \{ } \\
2627+
\em
2628+
\spc statements
2629+
2630+
{\cf \} }
2631+
\end{quote}
2632+
2633+
Parameters to functions are similar to shader parameters, except that
2634+
they do not permit initializers. A function call must pass values for
2635+
all formal parameters. Function parameters in \langname are all
2636+
\emph{passed by reference}, and are read-only within the body of the
2637+
function unless they are also designated as {\cf output} (in the same
2638+
manner as output shader parameters).
2639+
2640+
Like for shaders, statements inside functions may be actual executions
2641+
(assignments, function call, etc.), local variable declarations (visible
2642+
only from within the body of the function), or local function
2643+
declarations (callable only from within the body of the function).
2644+
2645+
The return type may be any simple data type, a {\cf struct}, or a {\cf
2646+
closure}. Functions may not return arrays. The return type may be
2647+
{\cf void}, indicating that the function does not return a value (and
2648+
should not contain a {\cf return} statement). A {\cf return} statement
2649+
inside the body of the function will halt execution of the function at
2650+
that point, and designates the value that will be returned (if not a
2651+
{\cf void} function).
2652+
2653+
Functions may be \emph{overloaded}. That is, multiple functions may be
2654+
defined to have the same name, as long as they have differently-typed
2655+
parameters, so that when the function is called the list of arguments
2656+
can disambiguate which version of the function is desired.
2657+
2658+
\subsection{Operator overloading}
2659+
\label{sec:syntax:operatoroverloading}
2660+
\index{operator overloading}
2661+
2662+
\NEW % 1.9
2663+
OSL permits \emph{operator overloading}, which is the practice of providing
2664+
a function that will be called when you use an operator like {\cf +} or
2665+
{\cf *}. This is especially handy when you use {\cf struct} to define
2666+
mathematical types and wish for the usual math operators to work with them.
2667+
Here is a typical example, which also shows the special naming convention
2668+
that allows operator overloading:
2669+
2670+
\begin{code}
2671+
struct vector4 {
2672+
float x, y, z, w;
2673+
};
2674+
2675+
vector4 __operator__add__ (vector4 a, vector4 b) {
2676+
return vector4 (a.x+b.x, a.y+b.y, a.z+b.z, a.w+b.w);
2677+
}
2678+
2679+
shader test ()
2680+
{
2681+
vector4 a = vector4 (.2, .3, .4, .5);
2682+
vector4 b = vector4 (1, 2, 3, 4);
2683+
2684+
vector4 c = a + b; // Will call __operator__add__(vector4,vector4)
2685+
printf ("a+b = %g %g %g %g\n", c.x, c.y, c.z, c.w);
2686+
}
2687+
\end{code}
2688+
2689+
\noindent The full list of these special function names is as follows (in
2690+
order of decreasing operator precedence):
2691+
2692+
\smallskip
2693+
2694+
\begin{tabular}{ p{0.5in} p{2in} p{1.75in}}
2695+
%op & overload function name & ~ \\
2696+
%\hline
2697+
{\cf \bfseries -} & {\cf __operator__neg__} & unary negation \\
2698+
{\cf \bfseries \textasciitilde} & {\cf __operator__compl__} & unary bitwise complement \\
2699+
{\cf \bfseries !} & {\cf __operator__not__} & unary boolean `not' \\[1.5ex]
2700+
{\cf \bfseries *} & {\cf __operator__mul__} & \\
2701+
{\cf \bfseries /} & {\cf __operator__div__} & \\
2702+
{\cf \bfseries \%} & {\cf __operator__mod__} & \\
2703+
{\cf \bfseries +} & {\cf __operator__add__} & \\
2704+
{\cf \bfseries -} & {\cf __operator__sub__} & \\[1.5ex]
2705+
{\cf \bfseries <<} & {\cf __operator__shl__} & \\
2706+
{\cf \bfseries >>} & {\cf __operator__shr__} & \\[1.5ex]
2707+
{\cf \bfseries <} & {\cf __operator__lt__} & \\
2708+
{\cf \bfseries <=} & {\cf __operator__le__} & \\
2709+
{\cf \bfseries >} & {\cf __operator__gt__} & \\
2710+
{\cf \bfseries >=} & {\cf __operator__ge__} & \\
2711+
{\cf \bfseries ==} & {\cf __operator__eq__} & \\
2712+
{\cf \bfseries !=} & {\cf __operator__ne__} & \\[1.5ex]
2713+
{\cf \bfseries \&} & {\cf __operator__bitand__} & \\
2714+
{\cf \bfseries \textasciicircum} & {\cf __operator__xor__} & \\
2715+
{\cf \bfseries |} & {\cf __operator__bitor__} & \\
2716+
%{\cf \bfseries \&\&} & {\cf __operator__and__} & boolean and \\
2717+
%{\cf \bfseries ||} & {\cf __operator__or__} & boolean or \\
2718+
%\hline
2719+
\end{tabular}
26562720

26572721

26582722
\section{Global variables}

src/liboslcomp/ast.cpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,18 @@ ASTassign_expression::opword () const
854854

855855

856856

857+
ASTunary_expression::ASTunary_expression (OSLCompilerImpl *comp, int op,
858+
ASTNode *expr)
859+
: ASTNode (unary_expression_node, comp, op, expr)
860+
{
861+
// Check for a user-overloaded function for this operator
862+
Symbol *sym = comp->symtab().find (ustring::format ("__operator__%s__", opword()));
863+
if (sym && sym->symtype() == SymTypeFunction)
864+
m_function_overload = (FunctionSymbol *)sym;
865+
}
866+
867+
868+
857869
const char *
858870
ASTunary_expression::childname (size_t i) const
859871
{
@@ -891,6 +903,22 @@ ASTunary_expression::opword () const
891903

892904

893905

906+
ASTbinary_expression::ASTbinary_expression (OSLCompilerImpl *comp, Operator op,
907+
ASTNode *left, ASTNode *right)
908+
: ASTNode (binary_expression_node, comp, op, left, right)
909+
{
910+
// Check for a user-overloaded function for this operator.
911+
// Disallow a few ops from overloading.
912+
if (op != And && op != Or) {
913+
ustring funcname = ustring::format ("__operator__%s__", opword());
914+
Symbol *sym = comp->symtab().find (funcname);
915+
if (sym && sym->symtype() == SymTypeFunction)
916+
m_function_overload = (FunctionSymbol *)sym;
917+
}
918+
}
919+
920+
921+
894922
const char *
895923
ASTbinary_expression::childname (size_t i) const
896924
{
@@ -985,13 +1013,14 @@ ASTtype_constructor::childname (size_t i) const
9851013

9861014

9871015
ASTfunction_call::ASTfunction_call (OSLCompilerImpl *comp, ustring name,
988-
ASTNode *args)
1016+
ASTNode *args, FunctionSymbol *funcsym)
9891017
: ASTNode (function_call_node, comp, 0, args), m_name(name),
9901018
m_argread(~1), // Default - all args are read except the first
9911019
m_argwrite(1), // Default - first arg only is written by the op
9921020
m_argtakesderivs(0) // Default - doesn't take derivs
9931021
{
994-
m_sym = comp->symtab().find (name);
1022+
// If we weren't passed a function symbol directly, look it up.
1023+
m_sym = funcsym ? funcsym : comp->symtab().find (name);
9951024
if (! m_sym) {
9961025
error ("function '%s' was not declared in this scope", name.c_str());
9971026
// FIXME -- would be fun to troll through the symtab and try to

src/liboslcomp/ast.h

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ class ASTNode : public OIIO::RefCnt {
159159
/// reference-counted pointer to *this.
160160
ref append (ref &x) { append (x.get()); return this; }
161161

162+
/// Detatch any 'next' nodes.
163+
void detach_next () { m_next.reset(); }
164+
162165
/// Concatenate ASTNode sequences A and B, returning a raw pointer to
163166
/// the concatenated sequence. This is robust to either A or B or
164167
/// both being NULL.
@@ -290,6 +293,11 @@ class ASTNode : public OIIO::RefCnt {
290293
///
291294
void typecheck_children (TypeSpec expected = TypeSpec());
292295

296+
/// Helper for check_arglist: simple case of checking a single arg type
297+
/// agaisnt the front of the formals list (which will be advanced).
298+
bool check_simple_arg (const TypeSpec &argtype,
299+
const char * &formals, bool coerce);
300+
293301
/// Type check a list (whose head is given by 'arg' against the list
294302
/// of expected types given in encoded form by 'formals'.
295303
bool check_arglist (const char *funcname, ref arg,
@@ -731,9 +739,7 @@ class ASTassign_expression : public ASTNode
731739
class ASTunary_expression : public ASTNode
732740
{
733741
public:
734-
ASTunary_expression (OSLCompilerImpl *comp, int op, ASTNode *expr)
735-
: ASTNode (unary_expression_node, comp, op, expr)
736-
{ }
742+
ASTunary_expression (OSLCompilerImpl *comp, int op, ASTNode *expr);
737743

738744
const char *nodetypename () const { return "unary_expression"; }
739745
const char *childname (size_t i) const;
@@ -743,6 +749,8 @@ class ASTunary_expression : public ASTNode
743749
Symbol *codegen (Symbol *dest = NULL);
744750

745751
ref expr () const { return child (0); }
752+
private:
753+
FunctionSymbol *m_function_overload = nullptr;
746754
};
747755

748756

@@ -751,9 +759,7 @@ class ASTbinary_expression : public ASTNode
751759
{
752760
public:
753761
ASTbinary_expression (OSLCompilerImpl *comp, Operator op,
754-
ASTNode *left, ASTNode *right)
755-
: ASTNode (binary_expression_node, comp, op, left, right)
756-
{ }
762+
ASTNode *left, ASTNode *right);
757763

758764
const char *nodetypename () const { return "binary_expression"; }
759765
const char *childname (size_t i) const;
@@ -765,6 +771,8 @@ class ASTbinary_expression : public ASTNode
765771
ref left () const { return child (0); }
766772
ref right () const { return child (1); }
767773
private:
774+
FunctionSymbol *m_function_overload = nullptr;
775+
768776
// Special code generation for short-circuiting logical ops
769777
Symbol *codegen_logic (Symbol *dest);
770778
};
@@ -853,7 +861,8 @@ class ASTtype_constructor : public ASTNode
853861
class ASTfunction_call : public ASTNode
854862
{
855863
public:
856-
ASTfunction_call (OSLCompilerImpl *comp, ustring name, ASTNode *args);
864+
ASTfunction_call (OSLCompilerImpl *comp, ustring name, ASTNode *args,
865+
FunctionSymbol *funcsym = nullptr);
857866
const char *nodetypename () const { return "function_call"; }
858867
const char *childname (size_t i) const;
859868
const char *opname () const;

src/liboslcomp/codegen.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,6 +1360,15 @@ ASTunary_expression::codegen (Symbol *dest)
13601360
{
13611361
// Code generation for unary expressions (-x, !x, etc.)
13621362

1363+
if (m_function_overload) {
1364+
// A little crazy, but we temporarily construct an ASTfunction_call
1365+
// in order to codegen this overloaded operator.
1366+
ustring funcname = ustring::format ("__operator__%s__", opword());
1367+
ASTfunction_call fc (m_compiler, funcname, expr().get(), m_function_overload);
1368+
fc.typecheck (typespec());
1369+
return dest = fc.codegen (dest);
1370+
}
1371+
13631372
if (m_op == Not) {
13641373
// Special case for logical ops
13651374
return expr()->codegen_int (NULL, true /*boolify*/, true /*invert*/);
@@ -1396,6 +1405,27 @@ ASTunary_expression::codegen (Symbol *dest)
13961405
Symbol *
13971406
ASTbinary_expression::codegen (Symbol *dest)
13981407
{
1408+
if (m_function_overload) {
1409+
// A little crazy, but we temporarily construct an ASTfunction_call
1410+
// in order to codegen this overloaded operator. Slightly tricky
1411+
// is that we need to concatenate our left and right arguments into
1412+
// an arg list.
1413+
ustring funcname = ustring::format ("__operator__%s__", opword());
1414+
if (left()->nextptr() || right()->nextptr()) {
1415+
error ("Overloaded %s cannot be passed arguments %s and %s",
1416+
funcname, left()->nodetypename(), right()->nodetypename());
1417+
return dest;
1418+
}
1419+
ref args = left();
1420+
args->append (right().get());
1421+
ASTfunction_call fc (m_compiler, funcname, args.get(), m_function_overload);
1422+
fc.typecheck (typespec());
1423+
dest = fc.codegen (dest);
1424+
// now put things back the way we found them
1425+
left()->detach_next ();
1426+
return dest;
1427+
}
1428+
13991429
// Special case for logic ops that short-circuit
14001430
if (m_op == And || m_op == Or)
14011431
return codegen_logic (dest);

0 commit comments

Comments
 (0)