Skip to content

Commit 006d530

Browse files
issue/810 static compute graph infra
1 parent caa61e9 commit 006d530

23 files changed

Lines changed: 480 additions & 30 deletions

File tree

include/infinicore/context/context.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#include "../device.hpp"
44
#include "../memory.hpp"
55

6+
#include "../graph/graph.hpp"
7+
68
#include <infiniop.h>
79
#include <infinirt.h>
810

@@ -40,6 +42,12 @@ void destroyEvent(infinirtEvent_t event);
4042
float elapsedTime(infinirtEvent_t start, infinirtEvent_t end);
4143
void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event);
4244

45+
// Graph recording APIs
// Free functions of namespace infinicore::context that control recording of
// graph::GraphOperator objects into a graph::Graph.
46+
/// @brief Query the graph-recording state.
/// @return bool; presumably true between startGraphRecording() and
///         stopGraphRecording() — TODO confirm against the runtime.
bool isGraphRecording();
47+
/// @brief Begin graph recording. Takes no arguments and returns nothing.
void startGraphRecording();
48+
/// @brief Hand an operator to the recorder.
/// @param op Shared pointer to the operator; taken by value.
void addGraphOperator(std::shared_ptr<graph::GraphOperator> op);
49+
/// @brief End graph recording.
/// @return Shared pointer to a graph::Graph; presumably the graph holding the
///         operators added since startGraphRecording() — TODO confirm.
std::shared_ptr<graph::Graph> stopGraphRecording();
50+
4351
} // namespace context
4452

4553
} // namespace infinicore

include/infinicore/graph/graph.hpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <vector>
5+
6+
#include "../tensor.hpp"
7+
8+
namespace infinicore::graph {
9+
// Forward declarations
10+
class GraphManager;
11+
12+
// Tensor subclass whose only declared member is a constructor taking a
// const Tensor &; the constructor is defined out of line.
// NOTE(review): the constructor is not marked explicit, so a Tensor converts
// implicitly to GraphTensor — confirm that this conversion is intended.
class GraphTensor : public Tensor {
13+
public:
14+
GraphTensor(const Tensor &);
15+
};
16+
17+
// Base class for operators that a Graph stores (Graph::op_list_ holds
// std::shared_ptr<GraphOperator>). op::Gemm derives publicly from it.
// NOTE(review): the destructor is not virtual even though the class is used as
// a public base; deleting a derived object through a raw GraphOperator * would
// be undefined behaviour — confirm every instance is owned by a shared_ptr
// created from the concrete type, or make the destructor virtual.
// NOTE(review): a destructor is declared while copy/move operations stay
// implicit, so a copy would duplicate the raw planned_meta_ pointer — verify
// that copies cannot lead to a double cleanup.
class GraphOperator {
18+
19+
public:
20+
// Runs the operator; defined out of line.
// Presumably calls runner_(planned_meta_) — TODO confirm in the .cc file.
void run() const;
21+
// Defined out of line. Presumably releases planned_meta_ through deleter_ —
// TODO confirm.
~GraphOperator();
22+
23+
protected:
24+
// Function-pointer type taking one void * and returning void.
using run_schema = void (*)(void *);
25+
// Function-pointer type taking a void ** and returning void.
using cleanup_schema = void (*)(void **);
26+
// Opaque per-operator state.
// NOTE(review): none of the three members below has a default initializer, so
// a derived constructor must assign all of them before run() or the
// destructor can use them safely.
void *planned_meta_;
27+
run_schema runner_;
28+
cleanup_schema deleter_;
29+
};
30+
31+
// Container of graph operators held as shared pointers in op_list_.
// Construction and destruction are defaulted; operators can only be appended
// through the protected add_operator(), which the friend class GraphManager
// can reach.
class Graph {
32+
public:
33+
Graph() = default;
34+
~Graph() = default;
35+
36+
// Runs the graph; defined out of line.
// Presumably runs each entry of op_list_ in insertion order — TODO confirm.
void run() const;
37+
38+
protected:
39+
// Adds an operator to the graph; defined out of line. Takes the shared
// pointer by value.
void add_operator(std::shared_ptr<GraphOperator> op);
40+
41+
// Operators owned (shared) by this graph.
std::vector<std::shared_ptr<GraphOperator>> op_list_;
42+
43+
// Grants GraphManager access to add_operator() and op_list_.
friend class GraphManager;
44+
};
45+
} // namespace infinicore::graph

include/infinicore/ops/gemm.hpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,24 @@
11
#pragma once
22

33
#include "../device.hpp"
4+
#include "../graph/graph.hpp"
45
#include "common/op.hpp"
56

67
namespace infinicore::op {
78

8-
class Gemm {
9+
// GEMM operator. Derives from graph::GraphOperator, which supplies the
// run_schema and cleanup_schema aliases used by the dispatcher accessors below.
class Gemm : public graph::GraphOperator {
910
public:
1011
// Function-pointer type with the same parameter types as execute():
// three Tensors and two floats, returning void.
using schema = void (*)(Tensor, Tensor, Tensor, float, float);
12+
// Same parameter types as schema, but returns an opaque void *.
// Presumably the planned state later passed to a run_schema function —
// TODO confirm.
using plan_schema = void *(*)(Tensor, Tensor, Tensor, float, float);
13+
14+
// Constructs an operator instance from output c, inputs a and b, and the
// scalars alpha and beta; defined out of line.
// Presumably represents c = alpha * a * b + beta * c — TODO confirm.
Gemm(Tensor c, Tensor a, Tensor b, float alpha, float beta);
15+
1116
// Static entry point taking the same arguments as the constructor.
static void execute(Tensor c, Tensor a, Tensor b, float alpha, float beta);
17+
1218
// Accessors returning a reference to the dispatcher for each function-pointer
// type: schema, plan_schema, and the inherited run_schema / cleanup_schema.
static common::OpDispatcher<schema> &dispatcher();
19+
static common::OpDispatcher<plan_schema> &plan_dispatcher();
20+
static common::OpDispatcher<run_schema> &run_dispatcher();
21+
static common::OpDispatcher<cleanup_schema> &cleanup_dispatcher();
1322
};
1423

1524
Tensor gemm(Tensor a, Tensor b, float alpha = 1.0f, float beta = 0.0f);

include/infinicore/tensor.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ class TensorImpl : public std::enable_shared_from_this<TensorImpl> {
133133

134134
void debug() const;
135135

136+
Tensor to_blob() const;
137+
136138
///
137139
/// Data Transfer APIs
138140
///
@@ -294,7 +296,7 @@ class TensorImpl : public std::enable_shared_from_this<TensorImpl> {
294296

295297
friend class Tensor;
296298

297-
private:
299+
protected:
298300
TensorMetaData meta_;
299301
TensorData data_;
300302
};

python/infinicore/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
get_device,
99
get_device_count,
1010
get_stream,
11+
is_graph_recording,
1112
set_device,
13+
start_graph_recording,
14+
stop_graph_recording,
1215
sync_device,
1316
sync_stream,
1417
)
@@ -80,6 +83,9 @@
8083
"set_device",
8184
"sync_device",
8285
"sync_stream",
86+
"is_graph_recording",
87+
"start_graph_recording",
88+
"stop_graph_recording",
8389
# Data Types.
8490
"bfloat16",
8591
"bool",

python/infinicore/context.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import infinicore.device
2+
from infinicore.graph import Graph
23
from infinicore.lib import _infinicore
34

45

@@ -49,3 +50,24 @@ def get_stream():
4950
stream: The current stream object
5051
"""
5152
return _infinicore.get_stream()
53+
54+
55+
def is_graph_recording():
56+
"""Check if the current graph is recording.
57+
58+
Returns:
59+
bool: True if the current graph is recording, False otherwise
60+
"""
61+
# Thin wrapper: the answer comes straight from the _infinicore binding.
return _infinicore.is_graph_recording()
62+
63+
64+
def start_graph_recording(device=None):
65+
"""Start recording the current graph.

Args:
    device: Optional device. When it is not None, set_device(device) is
        called before recording is started; when None, the device is
        left unchanged.
"""
66+
if device is not None:
67+
set_device(device)
68+
_infinicore.start_graph_recording()
69+
70+
71+
def stop_graph_recording():
72+
"""Stop recording the current graph.

Returns:
    Graph: Python wrapper around the object returned by
        _infinicore.stop_graph_recording().
"""
73+
return Graph(_infinicore.stop_graph_recording())

python/infinicore/graph.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from infinicore.lib import _infinicore
2+
3+
4+
class Graph:
5+
"""
6+
Python wrapper around an InfiniCore Graph instance.
7+
"""
8+
9+
# Stores the wrapped _infinicore.Graph in self._graph.
# Raises TypeError when `graph` is not an _infinicore.Graph instance.
def __init__(self, graph: _infinicore.Graph):
10+
if not isinstance(graph, _infinicore.Graph):
11+
raise TypeError("Expected _infinicore.Graph")
12+
self._graph = graph
13+
14+
# Delegates to the wrapped graph's run() and returns whatever it returns.
def run(self):
15+
return self._graph.run()
16+
17+
# Debug representation that embeds the repr of the wrapped graph.
def __repr__(self):
18+
return f"<Graph wrapper of {self._graph!r}>"

src/infinicore/context/allocators/pinnable_block_allocator.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "pinnable_block_allocator.hpp"
22

3+
#include "../context_impl.hpp"
4+
35
#include "../../utils.hpp"
46

57
#include <algorithm>

src/infinicore/context/allocators/pinnable_block_allocator.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
#include "memory_allocator.hpp"
44

5-
#include "../context_impl.hpp"
6-
75
#include <mutex>
86
#include <unordered_map>
97
#include <vector>
@@ -25,7 +23,7 @@ class PinnableBlockAllocator : public MemoryAllocator {
2523
};
2624

2725
public:
28-
explicit PinnableBlockAllocator(Device device);
26+
PinnableBlockAllocator(Device device);
2927
~PinnableBlockAllocator();
3028

3129
std::byte *allocate(size_t size) override;

src/infinicore/context/context_impl.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ void ContextImpl::setDevice(Device device) {
3939
return;
4040
}
4141

42+
if (getCurrentRuntime()->isGraphRecording()) {
43+
spdlog::warn("Switching device runtime during graph recording may break the graph!");
44+
}
45+
4246
if (runtime_table_[int(device.getType())][device.getIndex()] == nullptr) {
4347
// Lazy initialization of runtime if never set before.
4448
runtime_table_[int(device.getType())][device.getIndex()] = std::unique_ptr<Runtime>(new Runtime(device));
@@ -178,6 +182,21 @@ void streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
178182
ContextImpl::singleton().getCurrentRuntime()->streamWaitEvent(stream, event);
179183
}
180184

185+
// Returns the result of isGraphRecording() on the runtime currently selected
// by the ContextImpl singleton.
bool isGraphRecording() {
186+
return ContextImpl::singleton().getCurrentRuntime()->isGraphRecording();
187+
}
188+
189+
// Forwards to startGraphRecording() on the runtime currently selected by the
// ContextImpl singleton.
// NOTE(review): the call targets whichever runtime is current at call time;
// setDevice() only logs a warning when the device is switched while recording —
// confirm callers start and stop recording on the same runtime.
void startGraphRecording() {
190+
ContextImpl::singleton().getCurrentRuntime()->startGraphRecording();
191+
}
192+
193+
// Forwards `op` to addGraphOperator() on the runtime currently selected by the
// ContextImpl singleton.
// NOTE(review): `op` is received by value and passed on as an lvalue; if the
// runtime method also takes the shared_ptr by value this costs an extra
// reference-count update that std::move(op) would avoid — confirm the callee's
// signature.
void addGraphOperator(std::shared_ptr<graph::GraphOperator> op) {
194+
ContextImpl::singleton().getCurrentRuntime()->addGraphOperator(op);
195+
}
196+
197+
// Returns the shared_ptr<graph::Graph> produced by stopGraphRecording() on the
// runtime currently selected by the ContextImpl singleton.
std::shared_ptr<graph::Graph> stopGraphRecording() {
198+
return ContextImpl::singleton().getCurrentRuntime()->stopGraphRecording();
199+
}
181200
} // namespace context
182201

183202
} // namespace infinicore

0 commit comments

Comments
 (0)