diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 18bf2bbbb..e3d010be0 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -442,15 +442,6 @@ def field(dtype, shape=None, offset=None, needs_grad=False): assert (offset is not None and shape is None ) == False, f'The shape cannot be None when offset is being set' - ''' - if get_runtime().materialized: - raise RuntimeError( - "No new variables can be declared after materialization, i.e. kernel invocations " - "or Python-scope field accesses. I.e., data layouts must be specified before " - "any computation. Try appending ti.init() or ti.reset() " - "right after 'import taichi as ti' if you are using Jupyter notebook or Blender." - ) - ''' del _taichi_skip_traceback diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 7b6c6e5b6..59d4b9335 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1251,12 +1251,19 @@ llvm::Value *CodeGenLLVM::call(SNode *snode, } void CodeGenLLVM::visit(GetRootStmt *stmt) { - llvm_val[stmt] = builder->CreateBitCast( - get_root(), - llvm::PointerType::get( - StructCompilerLLVM::get_llvm_node_type( - module.get(), prog->get_snode_root(SNodeTree::kFirstID)), - 0)); + if (stmt->root() == nullptr) + llvm_val[stmt] = builder->CreateBitCast( + get_root(SNodeTree::kFirstID), + llvm::PointerType::get( + StructCompilerLLVM::get_llvm_node_type( + module.get(), prog->get_snode_root(SNodeTree::kFirstID)), + 0)); + else + llvm_val[stmt] = builder->CreateBitCast( + get_root(stmt->root()->get_snode_tree_id()), + llvm::PointerType::get( + StructCompilerLLVM::get_llvm_node_type(module.get(), stmt->root()), + 0)); } void CodeGenLLVM::visit(BitExtractStmt *stmt) { @@ -2011,8 +2018,9 @@ llvm::Type *CodeGenLLVM::get_xlogue_function_type() { get_xlogue_argument_types(), false); } -llvm::Value *CodeGenLLVM::get_root() { - return create_call("LLVMRuntime_get_root", {get_runtime()}); +llvm::Value *CodeGenLLVM::get_root(int snode_tree_id) { + return create_call("LLVMRuntime_get_roots", + {get_runtime(), tlctx->get_constant(snode_tree_id)}); } llvm::Value *CodeGenLLVM::get_runtime() { diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index e95cd366c..255f677bc 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -101,7 +101,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Type *get_xlogue_function_type(); - llvm::Value *get_root(); + llvm::Value *get_root(int snode_tree_id); llvm::Value *get_runtime(); diff --git a/taichi/inc/constants.h b/taichi/inc/constants.h index 7241da5da..66a4eb227 100644 --- a/taichi/inc/constants.h +++ b/taichi/inc/constants.h @@ -5,6 +5,7 @@ constexpr int taichi_max_num_indices = 8; constexpr int taichi_max_num_args = 8; constexpr int taichi_max_num_snodes = 1024; +constexpr int taichi_max_num_snode_trees = 32; constexpr int taichi_max_gpu_block_dim = 1024; constexpr std::size_t taichi_global_tmp_buffer_size = 1024 * 1024; constexpr int taichi_max_num_mem_requests = 1024 * 64; diff --git a/taichi/ir/snode.cpp b/taichi/ir/snode.cpp index b875b74bd..2bdf25156 100644 --- a/taichi/ir/snode.cpp +++ b/taichi/ir/snode.cpp @@ -200,4 +200,12 @@ SNode *SNode::get_grad() const { return grad_info->grad_snode(); } +void SNode::set_snode_tree_id(int id) { + snode_tree_id_ = id; +} + +int SNode::get_snode_tree_id() { + return snode_tree_id_; +} + TLANG_NAMESPACE_END diff --git a/taichi/ir/snode.h b/taichi/ir/snode.h index 5b002a454..5740e0d27 100644 --- a/taichi/ir/snode.h +++ b/taichi/ir/snode.h @@ -293,6 +293,15 @@ class SNode { void begin_shared_exp_placement(); void end_shared_exp_placement(); + + // SNodeTree part + + void set_snode_tree_id(int id); + + int get_snode_tree_id(); + + private: + int snode_tree_id_{0}; }; } // namespace lang diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 209e88ac5..408d28e0d 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -912,7 +912,12 @@ class BitExtractStmt : public Stmt { */ class GetRootStmt : public Stmt { public: - GetRootStmt() { + GetRootStmt(SNode *root = nullptr) : root_(root) { + if (this->root_ != nullptr) { + while (this->root_->parent) { + this->root_ = this->root_->parent; + } + } TI_STMT_REG_FIELDS; } @@ -920,8 +925,15 @@ class GetRootStmt : public Stmt { return false; } - TI_STMT_DEF_FIELDS(ret_type); + TI_STMT_DEF_FIELDS(ret_type, root_); TI_DEFINE_ACCEPT_AND_CLONE + + SNode *root() { + return root_; + } + + private: + SNode *root_; }; /** diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp index 47c719a24..5a02c1dd0 100644 --- a/taichi/program/program.cpp +++ b/taichi/program/program.cpp @@ -402,7 +402,7 @@ void Program::initialize_llvm_runtime_snodes(const SNodeTree *tree, TI_TRACE("Allocating data structure of size {} bytes", scomp->root_size); runtime_jit->call( "runtime_initialize_snodes", llvm_runtime, scomp->root_size, root_id, - (int)snodes.size()); + (int)snodes.size(), tree->id()); for (int i = 0; i < (int)snodes.size(); i++) { if (is_gc_able(snodes[i]->type)) { std::size_t node_size; @@ -430,6 +430,7 @@ void Program::initialize_llvm_runtime_snodes(const SNodeTree *tree, int Program::add_snode_tree(std::unique_ptr root) { const int id = snode_trees_.size(); auto tree = std::make_unique(id, std::move(root)); + tree->root()->set_snode_tree_id(id); materialize_snode_tree(tree.get()); snode_trees_.push_back(std::move(tree)); return id; @@ -655,7 +656,9 @@ void Program::visualize_layout(const std::string &fn) { emit("]"); }; - visit(get_snode_root(SNodeTree::kFirstID)); + for (auto &a : snode_trees_) { + visit(a->root()); + } auto tail = R"( \end{tikzpicture} @@ -891,7 +894,9 @@ void Program::print_memory_profiler_info() { } }; - visit(get_snode_root(SNodeTree::kFirstID), /*depth=*/0); + for (auto &a : snode_trees_) { + visit(a->root(), /*depth=*/0); + } auto total_requested_memory = runtime_query( "LLVMRuntime_get_total_requested_memory", llvm_runtime); diff --git a/taichi/runtime/llvm/runtime.cpp b/taichi/runtime/llvm/runtime.cpp index 13df86bf8..43cb5a1ef 100644 --- a/taichi/runtime/llvm/runtime.cpp +++ b/taichi/runtime/llvm/runtime.cpp @@ -525,8 +525,10 @@ struct LLVMRuntime { host_printf_type host_printf; host_vsnprintf_type host_vsnprintf; Ptr program; - Ptr root; - size_t root_mem_size; + + Ptr roots[taichi_max_num_snode_trees]; + size_t root_mem_sizes[taichi_max_num_snode_trees]; + Ptr thread_pool; parallel_for_type parallel_for; ListManager *element_lists[taichi_max_num_snodes]; @@ -573,8 +575,8 @@ struct LLVMRuntime { // TODO: are these necessary? STRUCT_FIELD_ARRAY(LLVMRuntime, element_lists); STRUCT_FIELD_ARRAY(LLVMRuntime, node_allocators); -STRUCT_FIELD(LLVMRuntime, root); -STRUCT_FIELD(LLVMRuntime, root_mem_size); +STRUCT_FIELD_ARRAY(LLVMRuntime, roots); +STRUCT_FIELD_ARRAY(LLVMRuntime, root_mem_sizes); STRUCT_FIELD(LLVMRuntime, temporaries); STRUCT_FIELD(LLVMRuntime, assert_failed); STRUCT_FIELD(LLVMRuntime, host_printf); @@ -890,14 +892,15 @@ void runtime_initialize( void runtime_initialize_snodes(LLVMRuntime *runtime, std::size_t root_size, - int root_id, - int num_snodes) { + const int root_id, + const int num_snodes, + const int snode_tree_id) { // For Metal runtime, we have to make sure that both the beginning address // and the size of the root buffer memory are aligned to page size. - runtime->root_mem_size = + runtime->root_mem_sizes[snode_tree_id] = taichi::iroundup((size_t)root_size, taichi_page_size); - runtime->root = - runtime->allocate_aligned(runtime->root_mem_size, taichi_page_size); + runtime->roots[snode_tree_id] = runtime->allocate_aligned( + runtime->root_mem_sizes[snode_tree_id], taichi_page_size); // runtime->request_allocate_aligned ready to use // initialize the root node element list for (int i = root_id; i < root_id + num_snodes; i++) { @@ -908,7 +911,7 @@ void runtime_initialize_snodes(LLVMRuntime *runtime, Element elem; elem.loop_bounds[0] = 0; elem.loop_bounds[1] = 1; - elem.element = runtime->root; + elem.element = runtime->roots[snode_tree_id]; for (int i = 0; i < taichi_max_num_indices; i++) { elem.pcoord.val[i] = 0; } @@ -1743,9 +1746,10 @@ i32 wasm_materialize(Context *context) { (RandState *)((size_t)context->runtime + sizeof(LLVMRuntime)); // set random seed to (1, 0, 0, 0) context->runtime->rand_states[0].x = 1; - context->runtime->root = + // TODO: remove hard coding on root id 0(SNodeTree::kFirstID) + context->runtime->roots[0] = (Ptr)((size_t)context->runtime->rand_states + sizeof(RandState)); - return (i32)(size_t)context->runtime->root; + return (i32)(size_t)context->runtime->roots[0]; } } diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index cd1895b74..6c68cd3e8 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -466,7 +466,12 @@ class IRPrinter : public IRVisitor { } void visit(GetRootStmt *stmt) override { - print("{}{} = get root", stmt->type_hint(), stmt->name()); + if (stmt->root() == nullptr) + print("{}{} = get root nullptr", stmt->type_hint(), stmt->name()); + else + print("{}{} = get root [{}][{}]", stmt->type_hint(), stmt->name(), + stmt->root()->get_node_type_name_hinted(), + stmt->root()->type_name()); } void visit(SNodeLookupStmt *stmt) override { diff --git a/taichi/transforms/scalar_pointer_lowerer.cpp b/taichi/transforms/scalar_pointer_lowerer.cpp index b52b29e10..9a922452e 100644 --- a/taichi/transforms/scalar_pointer_lowerer.cpp +++ b/taichi/transforms/scalar_pointer_lowerer.cpp @@ -46,7 +46,10 @@ void ScalarPointerLowerer::run() { } } - Stmt *last = lowered_->push_back(); + if (path_length_ == 0) + return; + + Stmt *last = lowered_->push_back(snodes_[0]); for (int i = 0; i < path_length_; i++) { auto *snode = snodes_[i]; // TODO: Explain this condition diff --git a/tests/python/test_fields_builder.py b/tests/python/test_fields_builder.py index 927ad7b56..f315b4692 100644 --- a/tests/python/test_fields_builder.py +++ b/tests/python/test_fields_builder.py @@ -28,48 +28,57 @@ def func(): @ti.test(arch=[ti.cpu, ti.cuda]) -def test_fields_builder1(): +def test_fields_builder_dense(): n = 5 - x = ti.field(ti.f32, [n]) + + fb1 = ti.FieldsBuilder() + x = ti.field(ti.f32) + fb1.dense(ti.i, n).place(x) + fb1.finalize() @ti.kernel def func1(): for i in range(n): - x[i] = i * 2 + x[i] = i * 3 func1() for i in range(n): - assert x[i] == i * 2 + assert x[i] == i * 3 - fb = ti.FieldsBuilder() + fb2 = ti.FieldsBuilder() y = ti.field(ti.f32) - fb.dense(ti.i, n).place(y) - fb.finalize() + fb2.dense(ti.i, n).place(y) + z = ti.field(ti.f32) + fb2.dense(ti.i, n).place(z) + fb2.finalize() @ti.kernel def func2(): for i in range(n): - y[i] = i // 2 + x[i] = i * 2 + for i in range(n): + y[i] = i + 5 + for i in range(n): + z[i] = i + 10 func2() for i in range(n): - assert y[i] == i // 2 + assert x[i] == i * 2 + assert y[i] == i + 5 + assert z[i] == i + 10 func1() for i in range(n): - assert x[i] == i * 2 + assert x[i] == i * 3 @ti.test(arch=[ti.cpu, ti.cuda]) -def test_fields_builder2(): - # TODO: x, y share the same memory location - pass - ''' +def test_fields_builder_pointer(): n = 5 fb1 = ti.FieldsBuilder() x = ti.field(ti.f32) - fb1.dense(ti.i, n).place(x) + fb1.pointer(ti.i, n).place(x) fb1.finalize() @ti.kernel @@ -83,7 +92,9 @@ def func1(): fb2 = ti.FieldsBuilder() y = ti.field(ti.f32) - fb2.dense(ti.i, n).place(y) + fb2.pointer(ti.i, n).place(y) + z = ti.field(ti.f32) + fb2.pointer(ti.i, n).place(z) fb2.finalize() @ti.kernel @@ -92,9 +103,15 @@ def func2(): x[i] = i * 2 for i in range(n): y[i] = i + 5 + for i in range(n): + z[i] = i + 10 func2() for i in range(n): assert x[i] == i * 2 assert y[i] == i + 5 - ''' + assert z[i] == i + 10 + + func1() + for i in range(n): + assert x[i] == i * 3 diff --git a/tests/python/test_runtime.py b/tests/python/test_runtime.py index abaaf9601..624e184aa 100644 --- a/tests/python/test_runtime.py +++ b/tests/python/test_runtime.py @@ -131,52 +131,6 @@ def test_init_bad_arg(): ti.init(_test_mode=True, debug=True, foo_bar=233) -@ti.test(arch=ti.cpu) -def test_materialization_after_kernel(): - pass - ''' - x = ti.field(ti.f32, (3, 4)) - - @ti.kernel - def func(): - print(x[2, 3]) - - func() - - with pytest.raises(RuntimeError, match='declared after'): - y = ti.field(ti.f32, (2, 3)) - # ERROR: No new variable should be declared after kernel invocation! - ''' - - -@ti.test(arch=ti.cpu) -def test_materialization_after_access(): - pass - ''' - x = ti.field(ti.f32, (3, 4)) - - print(x[2, 3]) - - with pytest.raises(RuntimeError, match='declared after'): - y = ti.field(ti.f32, (2, 3)) - # ERROR: No new variable should be declared after Python-scope field access! - ''' - - -@ti.test(arch=ti.cpu) -def test_materialization_after_get_shape(): - pass - ''' - x = ti.field(ti.f32, (3, 4)) - - print(x.shape) - - with pytest.raises(RuntimeError, match='declared after'): - y = ti.field(ti.f32, (2, 3)) - # ERROR: No new variable should be declared after Python-scope field access! - ''' - - @ti.test(arch=ti.cpu) def test_materialize_callback(): x = ti.field(ti.f32, (3, 4))