Skip to content

Commit 4ca601d

Browse files
committed
implemented Pillar 9: Hardware-Accelerated Math.
1 parent c3dd289 commit 4ca601d

6 files changed

Lines changed: 108 additions & 5 deletions

File tree

include/ast.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ typedef enum {
6060
STMT_DISTRIBUTED_DECL,
6161
STMT_DISTRIBUTED_DECL,
6262
STMT_MODEL_DECL,
63-
STMT_QUANTUM_BLOCK
63+
STMT_MODEL_DECL,
64+
STMT_QUANTUM_BLOCK,
65+
STMT_GPU_BLOCK
6466
} StmtType;
6567

6668
// --- List Structures ---
@@ -174,6 +176,7 @@ typedef struct { char *name; StringList *capabilities; } NodeDeclStmt;
174176
typedef struct { char *name; StmtList *fields; } DistributedDeclStmt;
175177
typedef struct { char *name; char *architecture; StmtList *body; } ModelDeclStmt;
176178
typedef struct { StmtList *body; } QuantumBlockStmt;
179+
typedef struct { char *kernelName; StmtList *body; } GPUBlockStmt;
177180

178181
struct Stmt {
179182
StmtType type;
@@ -197,6 +200,7 @@ struct Stmt {
197200
DistributedDeclStmt distributed_decl;
198201
ModelDeclStmt model_decl;
199202
QuantumBlockStmt quantum_block;
203+
GPUBlockStmt gpu_block;
200204
} as;
201205
};
202206

@@ -249,6 +253,7 @@ Stmt *createNodeDeclStmt(const char *name, StringList *capabilities, int line, i
249253
Stmt *createDistributedDeclStmt(const char *name, StmtList *fields, int line, int column);
250254
Stmt *createModelDeclStmt(const char *name, const char *architecture, StmtList *body, int line, int column);
251255
Stmt *createQuantumBlockStmt(StmtList *body, int line, int column);
256+
Stmt *createGPUBlockStmt(const char *kernelName, StmtList *body, int line, int column);
252257

253258
ExprList *createExprList();
254259
void appendExpr(ExprList *list, Expr *expr);

include/scanner.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,13 @@ typedef enum {
154154
TOKEN_QUBIT,
155155
TOKEN_SUPERPOSE,
156156
TOKEN_ENTANGLE,
157+
TOKEN_SUPERPOSE,
158+
TOKEN_ENTANGLE,
157159
TOKEN_MEASURE,
160+
TOKEN_TENSOR,
161+
TOKEN_MATRIX,
162+
TOKEN_GPU,
163+
TOKEN_KERNEL,
158164

159165
TOKEN_ERROR,
160166
TOKEN_EOF

src/compiler/lexer/scanner.c

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,13 @@ static PxTokenType identifierType(Scanner *scanner) {
298298
}
299299
}
300300
break;
301+
case 'g':
302+
if (scanner->current - scanner->start > 1) {
303+
switch(scanner->start[1]) {
304+
case 'p': return checkKeyword(scanner, 2, 1, "u", TOKEN_GPU); // gpu
305+
}
306+
}
307+
break;
301308
case 'f':
302309
if (scanner->current - scanner->start > 1) {
303310
switch (scanner->start[1]) {
@@ -350,14 +357,26 @@ static PxTokenType identifierType(Scanner *scanner) {
350357
return checkKeyword(scanner, 2, 0, "", TOKEN_IS);
351358
}
352359
}
353-
break;
360+
case 'k':
361+
if (scanner->current - scanner->start > 1) {
362+
switch(scanner->start[1]) {
363+
case 'e': return checkKeyword(scanner, 2, 4, "rnel", TOKEN_KERNEL); // kernel
364+
}
365+
}
366+
break;
354367
case 'l':
355368
return checkKeyword(scanner, 1, 2, "et", TOKEN_LET);
356369
case 'm': // Added for 'match' and 'model'
357370
if (scanner->current - scanner->start > 1) {
358371
switch (scanner->start[1]) {
359372
case 'a':
360-
return checkKeyword(scanner, 2, 3, "tch", TOKEN_MATCH);
373+
if (scanner->current - scanner->start > 2) {
374+
if (scanner->start[2] == 't' && scanner->current - scanner->start > 3 && scanner->start[3] == 'r') {
375+
return checkKeyword(scanner, 4, 2, "ix", TOKEN_MATRIX); // matrix
376+
}
377+
return checkKeyword(scanner, 2, 3, "tch", TOKEN_MATCH);
378+
}
379+
break;
361380
case 'o':
362381
return checkKeyword(scanner, 2, 3, "del", TOKEN_MODEL); // model
363382
case 'u':
@@ -400,7 +419,7 @@ static PxTokenType identifierType(Scanner *scanner) {
400419
}
401420
break;
402421
case 'r':
403-
if (scanner->current - scanner->start > 2) {
422+
if (scanner->current - scanner->start > 2) {
404423
switch(scanner->start[2]) {
405424
case 'e': // predict
406425
return checkKeyword(scanner, 3, 4, "dict", TOKEN_PREDICT);
@@ -486,7 +505,13 @@ static PxTokenType identifierType(Scanner *scanner) {
486505
case 't':
487506
if (scanner->current - scanner->start > 1) {
488507
switch (scanner->start[1]) {
489-
case 'h':
508+
case 'e':
509+
if (scanner->current - scanner->start > 2) {
510+
if (scanner->start[2] == 'm') return checkKeyword(scanner, 3, 5, "poral", TOKEN_TEMPORAL);
511+
if (scanner->start[2] == 'n') return checkKeyword(scanner, 3, 3, "sor", TOKEN_TENSOR);
512+
}
513+
break;
514+
case 'h':
490515
if (scanner->current - scanner->start > 2) {
491516
switch (scanner->start[2]) {
492517
case 'i':

src/compiler/parser/ast.c

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,16 @@ Stmt *createQuantumBlockStmt(StmtList *body, int line, int column) {
635635
return stmt;
636636
}
637637

638+
Stmt *createGPUBlockStmt(const char *kernelName, StmtList *body, int line, int column) {
639+
Stmt *stmt = ALLOCATE(Stmt, 1);
640+
stmt->type = STMT_GPU_BLOCK;
641+
stmt->line = line;
642+
stmt->column = column;
643+
stmt->as.gpu_block.kernelName = kernelName ? strdup(kernelName) : NULL;
644+
stmt->as.gpu_block.body = body;
645+
return stmt;
646+
}
647+
638648
// --- Free Functions ---
639649

640650
void freeExpr(Expr *expr) {
@@ -845,6 +855,10 @@ void freeStmt(Stmt *stmt) {
845855
case STMT_QUANTUM_BLOCK:
846856
if(stmt->as.quantum_block.body) freeStmtList(stmt->as.quantum_block.body);
847857
break;
858+
case STMT_GPU_BLOCK:
859+
if(stmt->as.gpu_block.kernelName) free(stmt->as.gpu_block.kernelName);
860+
if(stmt->as.gpu_block.body) freeStmtList(stmt->as.gpu_block.body);
861+
break;
848862
}
849863

850864
FREE(Stmt, stmt);

src/compiler/parser/parser.c

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ static Stmt *nodeDecl(Parser *p); // Forward
2727
static Stmt *distributedDecl(Parser *p); // Forward
2828
static Stmt *modelDecl(Parser *p); // Forward
2929
static Stmt *quantumStmt(Parser *p); // Forward
30+
static Stmt *gpuStmt(Parser *p); // Forward
3031
static Stmt *useDecl(Parser *p);
3132
static Stmt *forStmt(Parser *p);
3233
static Stmt *ifStmt(Parser *p);
@@ -273,6 +274,8 @@ static Stmt *declaration(Parser *p) {
273274
return modelDecl(p);
274275
if (match(p, 1, TOKEN_QUANTUM))
275276
return quantumStmt(p);
277+
if (match(p, 1, TOKEN_GPU))
278+
return gpuStmt(p);
276279

277280
return statement(p);
278281
}
@@ -1179,3 +1182,41 @@ static Stmt *distributedDecl(Parser *p) {
11791182
free(name);
11801183
return stmt;
11811184
}
1185+
1186+
static Stmt *modelDecl(Parser *p) {
1187+
Token nameTok = consume(p, TOKEN_IDENTIFIER, "Expect model name.");
1188+
char *name = tokenToString(nameTok);
1189+
1190+
char *arch = NULL;
1191+
1192+
consume(p, TOKEN_LEFT_BRACE, "Expect '{'.");
1193+
StmtList *body = block(p);
1194+
1195+
Stmt *stmt = createModelDeclStmt(name, arch, body, nameTok.line, 0);
1196+
free(name);
1197+
return stmt;
1198+
}
1199+
1200+
static Stmt *quantumStmt(Parser *p) {
1201+
Token keyword = p->previous;
1202+
consume(p, TOKEN_LEFT_BRACE, "Expect '{' after quantum.");
1203+
StmtList *body = block(p);
1204+
return createQuantumBlockStmt(body, keyword.line, 0);
1205+
}
1206+
1207+
static Stmt *gpuStmt(Parser *p) {
1208+
Token keyword = p->previous;
1209+
char *kernelName = NULL;
1210+
1211+
if (match(p, 1, TOKEN_KERNEL)) {
1212+
Token nameTok = consume(p, TOKEN_IDENTIFIER, "Expect kernel name.");
1213+
kernelName = tokenToString(nameTok);
1214+
}
1215+
1216+
consume(p, TOKEN_LEFT_BRACE, "Expect '{' after gpu (kernel ...).");
1217+
StmtList *body = block(p);
1218+
1219+
Stmt *stmt = createGPUBlockStmt(kernelName, body, keyword.line, 0);
1220+
if(kernelName) free(kernelName);
1221+
return stmt;
1222+
}

src/compiler/type_checker.c

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,18 @@ static void checkStmt(TypeChecker* checker, Stmt* stmt) {
612612
break;
613613
}
614614

615+
case STMT_GPU_BLOCK: {
616+
beginScope(checker);
617+
StmtList* body = stmt->as.gpu_block.body;
618+
if (body) {
619+
for(int i=0; i<body->count; i++) {
620+
checkStmt(checker, body->items[i]);
621+
}
622+
}
623+
endScope(checker);
624+
break;
625+
}
626+
615627
default:
616628
break;
617629
}

0 commit comments

Comments
 (0)