Skip to content

Commit

Permalink
[lang] Support ti.FieldsBuilder() (#2501)
Browse files Browse the repository at this point in the history
* temp

* dynamic wip (runtime get root)

* dynamic snode

* clean up comments

* clean up commented code

* hide root in GetRootStmt

* resolve conversations

* remove num_roots

* add const

* edit test

* Auto Format

* add default value

* add default value to pass tests

* Auto Format

Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
ljcc0930 and taichi-gardener authored Jul 8, 2021
1 parent dedd976 commit a584008
Show file tree
Hide file tree
Showing 13 changed files with 117 additions and 100 deletions.
9 changes: 0 additions & 9 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,15 +442,6 @@ def field(dtype, shape=None, offset=None, needs_grad=False):

assert (offset is not None and shape is None
) == False, f'The shape cannot be None when offset is being set'
'''
if get_runtime().materialized:
raise RuntimeError(
"No new variables can be declared after materialization, i.e. kernel invocations "
"or Python-scope field accesses. I.e., data layouts must be specified before "
"any computation. Try appending ti.init() or ti.reset() "
"right after 'import taichi as ti' if you are using Jupyter notebook or Blender."
)
'''

del _taichi_skip_traceback

Expand Down
24 changes: 16 additions & 8 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1251,12 +1251,19 @@ llvm::Value *CodeGenLLVM::call(SNode *snode,
}

// Produces the LLVM value for a GetRootStmt: the pointer returned by
// get_root(...), bit-cast to a pointer to the LLVM struct type that
// StructCompilerLLVM::get_llvm_node_type reports for the root SNode.
// NOTE(review): this span looks like a rendered diff without +/- markers. The
// unconditional assignment directly below (calling get_root() with no
// argument) is presumably the pre-change body that the if/else replaces --
// confirm against the real file.
void CodeGenLLVM::visit(GetRootStmt *stmt) {
llvm_val[stmt] = builder->CreateBitCast(
get_root(),
llvm::PointerType::get(
StructCompilerLLVM::get_llvm_node_type(
module.get(), prog->get_snode_root(SNodeTree::kFirstID)),
0));
// A statement that carries no root SNode falls back to the tree with id
// SNodeTree::kFirstID, for both the runtime pointer and the node type.
if (stmt->root() == nullptr)
llvm_val[stmt] = builder->CreateBitCast(
get_root(SNodeTree::kFirstID),
llvm::PointerType::get(
StructCompilerLLVM::get_llvm_node_type(
module.get(), prog->get_snode_root(SNodeTree::kFirstID)),
0));
// Otherwise the tree is selected by the id stored on the statement's root
// SNode, and the pointer is typed after that same SNode.
else
llvm_val[stmt] = builder->CreateBitCast(
get_root(stmt->root()->get_snode_tree_id()),
llvm::PointerType::get(
StructCompilerLLVM::get_llvm_node_type(module.get(), stmt->root()),
0));
}

void CodeGenLLVM::visit(BitExtractStmt *stmt) {
Expand Down Expand Up @@ -2011,8 +2018,9 @@ llvm::Type *CodeGenLLVM::get_xlogue_function_type() {
get_xlogue_argument_types(), false);
}

// NOTE(review): the next two lines (zero-argument get_root calling
// "LLVMRuntime_get_root") have no closing brace here and are presumably the
// removed pre-change version shown by the diff view -- confirm.
llvm::Value *CodeGenLLVM::get_root() {
return create_call("LLVMRuntime_get_root", {get_runtime()});
// Emits a call to the runtime function "LLVMRuntime_get_roots", passing the
// runtime pointer and `snode_tree_id` as a constant, and returns the
// resulting value.
llvm::Value *CodeGenLLVM::get_root(int snode_tree_id) {
return create_call("LLVMRuntime_get_roots",
{get_runtime(), tlctx->get_constant(snode_tree_id)});
}

llvm::Value *CodeGenLLVM::get_runtime() {
Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

llvm::Type *get_xlogue_function_type();

llvm::Value *get_root();
llvm::Value *get_root(int snode_tree_id);

llvm::Value *get_runtime();

Expand Down
1 change: 1 addition & 0 deletions taichi/inc/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// Number of entries iterated when zeroing an Element's pcoord.val in the LLVM
// runtime.
constexpr int taichi_max_num_indices = 8;
constexpr int taichi_max_num_args = 8;
// Length of the per-SNode element_lists array in LLVMRuntime.
constexpr int taichi_max_num_snodes = 1024;
// Length of the per-tree roots / root_mem_sizes arrays in LLVMRuntime.
// NOTE(review): presumably the upper bound on valid SNode tree ids; no bounds
// check against it is visible here -- confirm.
constexpr int taichi_max_num_snode_trees = 32;
constexpr int taichi_max_gpu_block_dim = 1024;
constexpr std::size_t taichi_global_tmp_buffer_size = 1024 * 1024;
constexpr int taichi_max_num_mem_requests = 1024 * 64;
Expand Down
8 changes: 8 additions & 0 deletions taichi/ir/snode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,4 +200,12 @@ SNode *SNode::get_grad() const {
return grad_info->grad_snode();
}

void SNode::set_snode_tree_id(int id) {
snode_tree_id_ = id;
}

int SNode::get_snode_tree_id() {
return snode_tree_id_;
}

TLANG_NAMESPACE_END
9 changes: 9 additions & 0 deletions taichi/ir/snode.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,15 @@ class SNode {
void begin_shared_exp_placement();

void end_shared_exp_placement();

// SNodeTree part

// Records `id` as the SNode-tree id of this node.
void set_snode_tree_id(int id);

// Returns the recorded SNode-tree id (0 if set_snode_tree_id was never
// called on this node).
int get_snode_tree_id();

private:
// SNode-tree id recorded for this node; defaults to 0.
// NOTE(review): in the visible code only a tree's root node has this set
// (Program::add_snode_tree), so non-root nodes presumably keep the default --
// confirm before reading it from a non-root node.
int snode_tree_id_{0};
};

} // namespace lang
Expand Down
16 changes: 14 additions & 2 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -912,16 +912,28 @@ class BitExtractStmt : public Stmt {
*/
class GetRootStmt : public Stmt {
public:
// NOTE(review): the zero-argument constructor header on the next line has no
// body of its own here and is presumably the removed pre-change line from the
// diff view -- confirm.
GetRootStmt() {
// Accepts any SNode (or nullptr). A non-null argument is normalized by
// following `parent` links until a node without a parent is reached, so
// root_ ends up holding the topmost ancestor rather than the node passed in.
GetRootStmt(SNode *root = nullptr) : root_(root) {
if (this->root_ != nullptr) {
while (this->root_->parent) {
this->root_ = this->root_->parent;
}
}
TI_STMT_REG_FIELDS;
}

// This statement reports no global side effect.
bool has_global_side_effect() const override {
return false;
}

// NOTE(review): two TI_STMT_DEF_FIELDS lines appear back to back; the first
// (without root_) is presumably the removed pre-change line -- confirm.
TI_STMT_DEF_FIELDS(ret_type);
TI_STMT_DEF_FIELDS(ret_type, root_);
TI_DEFINE_ACCEPT_AND_CLONE

// Returns the stored root SNode; nullptr when the statement was constructed
// without one.
SNode *root() {
return root_;
}

private:
// Topmost ancestor of the SNode given to the constructor, or nullptr.
SNode *root_;
};

/**
Expand Down
11 changes: 8 additions & 3 deletions taichi/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ void Program::initialize_llvm_runtime_snodes(const SNodeTree *tree,
TI_TRACE("Allocating data structure of size {} bytes", scomp->root_size);
runtime_jit->call<void *, std::size_t, int, int>(
"runtime_initialize_snodes", llvm_runtime, scomp->root_size, root_id,
(int)snodes.size());
(int)snodes.size(), tree->id());
for (int i = 0; i < (int)snodes.size(); i++) {
if (is_gc_able(snodes[i]->type)) {
std::size_t node_size;
Expand Down Expand Up @@ -430,6 +430,7 @@ void Program::initialize_llvm_runtime_snodes(const SNodeTree *tree,
int Program::add_snode_tree(std::unique_ptr<SNode> root) {
const int id = snode_trees_.size();
auto tree = std::make_unique<SNodeTree>(id, std::move(root));
tree->root()->set_snode_tree_id(id);
materialize_snode_tree(tree.get());
snode_trees_.push_back(std::move(tree));
return id;
Expand Down Expand Up @@ -655,7 +656,9 @@ void Program::visualize_layout(const std::string &fn) {
emit("]");
};

visit(get_snode_root(SNodeTree::kFirstID));
for (auto &a : snode_trees_) {
visit(a->root());
}

auto tail = R"(
\end{tikzpicture}
Expand Down Expand Up @@ -891,7 +894,9 @@ void Program::print_memory_profiler_info() {
}
};

visit(get_snode_root(SNodeTree::kFirstID), /*depth=*/0);
for (auto &a : snode_trees_) {
visit(a->root(), /*depth=*/0);
}

auto total_requested_memory = runtime_query<std::size_t>(
"LLVMRuntime_get_total_requested_memory", llvm_runtime);
Expand Down
28 changes: 16 additions & 12 deletions taichi/runtime/llvm/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,8 +525,10 @@ struct LLVMRuntime {
host_printf_type host_printf;
host_vsnprintf_type host_vsnprintf;
Ptr program;
Ptr root;
size_t root_mem_size;

Ptr roots[taichi_max_num_snode_trees];
size_t root_mem_sizes[taichi_max_num_snode_trees];

Ptr thread_pool;
parallel_for_type parallel_for;
ListManager *element_lists[taichi_max_num_snodes];
Expand Down Expand Up @@ -573,8 +575,8 @@ struct LLVMRuntime {
// TODO: are these necessary?
STRUCT_FIELD_ARRAY(LLVMRuntime, element_lists);
STRUCT_FIELD_ARRAY(LLVMRuntime, node_allocators);
STRUCT_FIELD(LLVMRuntime, root);
STRUCT_FIELD(LLVMRuntime, root_mem_size);
STRUCT_FIELD_ARRAY(LLVMRuntime, roots);
STRUCT_FIELD_ARRAY(LLVMRuntime, root_mem_sizes);
STRUCT_FIELD(LLVMRuntime, temporaries);
STRUCT_FIELD(LLVMRuntime, assert_failed);
STRUCT_FIELD(LLVMRuntime, host_printf);
Expand Down Expand Up @@ -890,14 +892,15 @@ void runtime_initialize(

void runtime_initialize_snodes(LLVMRuntime *runtime,
std::size_t root_size,
int root_id,
int num_snodes) {
const int root_id,
const int num_snodes,
const int snode_tree_id) {
// For Metal runtime, we have to make sure that both the beginning address
// and the size of the root buffer memory are aligned to page size.
runtime->root_mem_size =
runtime->root_mem_sizes[snode_tree_id] =
taichi::iroundup((size_t)root_size, taichi_page_size);
runtime->root =
runtime->allocate_aligned(runtime->root_mem_size, taichi_page_size);
runtime->roots[snode_tree_id] = runtime->allocate_aligned(
runtime->root_mem_sizes[snode_tree_id], taichi_page_size);
// runtime->request_allocate_aligned ready to use
// initialize the root node element list
for (int i = root_id; i < root_id + num_snodes; i++) {
Expand All @@ -908,7 +911,7 @@ void runtime_initialize_snodes(LLVMRuntime *runtime,
Element elem;
elem.loop_bounds[0] = 0;
elem.loop_bounds[1] = 1;
elem.element = runtime->root;
elem.element = runtime->roots[snode_tree_id];
for (int i = 0; i < taichi_max_num_indices; i++) {
elem.pcoord.val[i] = 0;
}
Expand Down Expand Up @@ -1743,9 +1746,10 @@ i32 wasm_materialize(Context *context) {
(RandState *)((size_t)context->runtime + sizeof(LLVMRuntime));
// set random seed to (1, 0, 0, 0)
context->runtime->rand_states[0].x = 1;
context->runtime->root =
// TODO: remove the hard-coded root id 0 (SNodeTree::kFirstID)
context->runtime->roots[0] =
(Ptr)((size_t)context->runtime->rand_states + sizeof(RandState));
return (i32)(size_t)context->runtime->root;
return (i32)(size_t)context->runtime->roots[0];
}
}

Expand Down
7 changes: 6 additions & 1 deletion taichi/transforms/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,12 @@ class IRPrinter : public IRVisitor {
}

// Prints a GetRootStmt as "<type hint><name> = get root ...".
// NOTE(review): the unconditional print directly below is presumably the
// removed pre-change line from the diff view; the if/else that follows looks
// like its replacement -- confirm.
void visit(GetRootStmt *stmt) override {
print("{}{} = get root", stmt->type_hint(), stmt->name());
// No root SNode attached: print the literal marker "nullptr".
if (stmt->root() == nullptr)
print("{}{} = get root nullptr", stmt->type_hint(), stmt->name());
// Root SNode attached: append its hinted node-type name and its type name,
// each in square brackets.
else
print("{}{} = get root [{}][{}]", stmt->type_hint(), stmt->name(),
stmt->root()->get_node_type_name_hinted(),
stmt->root()->type_name());
}

void visit(SNodeLookupStmt *stmt) override {
Expand Down
5 changes: 4 additions & 1 deletion taichi/transforms/scalar_pointer_lowerer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ void ScalarPointerLowerer::run() {
}
}

Stmt *last = lowered_->push_back<GetRootStmt>();
if (path_length_ == 0)
return;

Stmt *last = lowered_->push_back<GetRootStmt>(snodes_[0]);
for (int i = 0; i < path_length_; i++) {
auto *snode = snodes_[i];
// TODO: Explain this condition
Expand Down
51 changes: 34 additions & 17 deletions tests/python/test_fields_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,48 +28,57 @@ def func():


@ti.test(arch=[ti.cpu, ti.cuda])
def test_fields_builder1():
def test_fields_builder_dense():
n = 5
x = ti.field(ti.f32, [n])

fb1 = ti.FieldsBuilder()
x = ti.field(ti.f32)
fb1.dense(ti.i, n).place(x)
fb1.finalize()

@ti.kernel
def func1():
for i in range(n):
x[i] = i * 2
x[i] = i * 3

func1()
for i in range(n):
assert x[i] == i * 2
assert x[i] == i * 3

fb = ti.FieldsBuilder()
fb2 = ti.FieldsBuilder()
y = ti.field(ti.f32)
fb.dense(ti.i, n).place(y)
fb.finalize()
fb2.dense(ti.i, n).place(y)
z = ti.field(ti.f32)
fb2.dense(ti.i, n).place(z)
fb2.finalize()

@ti.kernel
def func2():
for i in range(n):
y[i] = i // 2
x[i] = i * 2
for i in range(n):
y[i] = i + 5
for i in range(n):
z[i] = i + 10

func2()
for i in range(n):
assert y[i] == i // 2
assert x[i] == i * 2
assert y[i] == i + 5
assert z[i] == i + 10

func1()
for i in range(n):
assert x[i] == i * 2
assert x[i] == i * 3


@ti.test(arch=[ti.cpu, ti.cuda])
def test_fields_builder2():
# TODO: x, y share the same memory location
pass
'''
def test_fields_builder_pointer():
n = 5

fb1 = ti.FieldsBuilder()
x = ti.field(ti.f32)
fb1.dense(ti.i, n).place(x)
fb1.pointer(ti.i, n).place(x)
fb1.finalize()

@ti.kernel
Expand All @@ -83,7 +92,9 @@ def func1():

fb2 = ti.FieldsBuilder()
y = ti.field(ti.f32)
fb2.dense(ti.i, n).place(y)
fb2.pointer(ti.i, n).place(y)
z = ti.field(ti.f32)
fb2.pointer(ti.i, n).place(z)
fb2.finalize()

@ti.kernel
Expand All @@ -92,9 +103,15 @@ def func2():
x[i] = i * 2
for i in range(n):
y[i] = i + 5
for i in range(n):
z[i] = i + 10

func2()
for i in range(n):
assert x[i] == i * 2
assert y[i] == i + 5
'''
assert z[i] == i + 10

func1()
for i in range(n):
assert x[i] == i * 3
Loading

0 comments on commit a584008

Please sign in to comment.