preparing tile-mapped index generation and tile mapping matching; work in progress
mihaipgc committed Oct 23, 2023
1 parent 43cd3cd commit c81b013
Showing 1 changed file with 106 additions and 32 deletions.
138 changes: 106 additions & 32 deletions pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py
@@ -6,6 +6,8 @@
import jax.numpy as jnp
import os.path as osp
from tessellate_ipu import create_ipu_tile_primitive, ipu_cycle_count, tile_map, tile_put_sharded, tile_put_replicated
from tessellate_ipu.lax.tile_lax_binary import add_inplace_p, div_inplace_p, mul_inplace_p, rem_inplace_p
from tessellate_ipu.lax.tile_lax_unary import tile_copy
from functools import partial
from icecream import ic
from tqdm import tqdm
@@ -232,6 +234,7 @@ def compute_diff_jk(dm, mol, nbatch, tolerance, backend):

all_eris = []
all_indices = []
all_tiles = []
np.random.seed(42) # is this needed?

NUM_TILES = 1472
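# NUM_TILES = 1472 matches the tile count of a Graphcore Mk2 IPU (hardware fact, not stated in this diff)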
@@ -245,58 +248,127 @@ def compute_diff_jk(dm, mol, nbatch, tolerance, backend):
tile_floats = tile_put_replicated(input_floats, tiles)
tile_ints = tile_put_replicated(input_ints, tiles)

for i, (size, count) in enumerate(zip(sizes, counts)):
glen, nf = shapes[i]
# For each shell, defined by its offset index, compute an ERI value tensor of size `size`
for zip_counter, (size, count) in enumerate(zip(sizes, counts)):
glen, nf = shapes[zip_counter]
chunk_size = num_tiles * num_threads
num_full_batches = count//chunk_size

tiles = tuple((np.arange(num_tiles*num_threads)%(num_tiles)+1).tolist())
tile_g = tile_put_replicated(jnp.empty(min(int(glen), 3888)+1), tiles)
tile_idx = tile_put_replicated(jnp.empty(max(256, min(int(nf*3), 3888)+1), dtype=jnp.int32), tiles)
tile_buf = tile_put_replicated(jnp.empty(1080*4+1), tiles)
integral_size = tile_put_replicated(jnp.array(size, dtype=jnp.uint32), tiles)
tiles = np.arange(num_tiles*num_threads, dtype=np.uint32)%(num_tiles)+1
tile_g = tile_put_replicated(jnp.empty(min(int(glen), 3888)+1), tuple(tiles.tolist()))
tile_idx = tile_put_replicated(jnp.empty(max(256, min(int(nf*3), 3888)+1), dtype=jnp.int32), tuple(tiles.tolist()))
tile_buf = tile_put_replicated(jnp.empty(1080*4+1), tuple(tiles.tolist()))
integral_size = tile_put_replicated(jnp.array(size, dtype=jnp.uint32), tuple(tiles.tolist()))
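# Sketch of the slot->tile layout built above (illustrative numbers, not from this diff):
# with num_tiles=4 and num_threads=2,
#   np.arange(4*2, dtype=np.uint32) % 4 + 1  ->  [1, 2, 3, 4, 1, 2, 3, 4]
# i.e. each tile id in 1..num_tiles appears num_threads times and tile 0 stays free;
# keeping `tiles` as an array (instead of the old tuple) lets it double as the
# tile-mapping record that batched_compute returns below.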

def batched_compute(start, stop, chunk_size, tiles):
assert (stop-start) < chunk_size or (stop-start) % chunk_size == 0
num_batches = max(1, (stop-start)//chunk_size)
idx = np.array(input_ijkl[i][start:stop]).reshape(-1, num_batches, 4)
idx = np.array(input_ijkl[zip_counter][start:stop]).reshape(-1, num_batches, 4)
tile_mappings = tiles
tiles = tuple(tiles.tolist())
out, _, _, _ = tile_map(int2e_sph_forloop,
tile_floats[:len(tiles)],
tile_ints[:len(tiles)],
tile_put_sharded(idx, tiles),
tile_put_sharded(jnp.empty((len(tiles), num_batches, size)), tiles),
tile_put_sharded(jnp.empty((len(tiles), num_batches, size)), tiles), # this is where the output is mapped
tile_g[:len(tiles)],
tile_idx[:len(tiles)],
tile_buf[:len(tiles)],
tile_put_replicated(jnp.array(num_batches, dtype=jnp.uint32), tiles),
integral_size[:len(tiles)])
return out.array.reshape(-1, size), idx.reshape(-1, 4)
return out.array.reshape(-1, size), idx.reshape(-1, 4), tile_mappings.reshape(-1, 1)
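# Assumed contract of batched_compute (shapes inferred from the reshapes above; WIP):
#   vals, quads, tile_ids = batched_compute(start, stop, chunk_size, tiles)
#   vals     : (stop-start, size)  one row of ERI values per shell quadruplet
#   quads    : (stop-start, 4)     the (i,j,k,l) shell indices fed to int2e_sph_forloop
#   tile_ids : (len(tiles), 1)     tile each work slot ran on, kept so index generation
#                                  can later be sharded onto the same tiles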

if num_full_batches > 0: f_out, f_idx = batched_compute(0, num_full_batches*chunk_size, chunk_size, tiles)
else: f_out, f_idx = np.array([]).reshape(0, size), np.array([]).reshape(0, 4)
if num_full_batches > 0: f_out, f_idx, f_til = batched_compute(0, num_full_batches*chunk_size, chunk_size, tiles)
else: f_out, f_idx, f_til = np.array([]).reshape(0, size), np.array([]).reshape(0, 4), np.array([]).reshape(0, 1)

tiles = tuple((np.arange(count-num_full_batches*chunk_size)%(num_tiles)+1).tolist())
out, idx = batched_compute(num_full_batches*chunk_size, count, chunk_size, tiles)
tiles = np.arange(count-num_full_batches*chunk_size, dtype=np.uint32)%(num_tiles)+1
out, idx, til = batched_compute(num_full_batches*chunk_size, count, chunk_size, tiles)

all_eris.append(jnp.concatenate([f_out, out]))
all_indices.append(np.concatenate([f_idx, idx]).astype(np.uint8))
all_tiles.append(np.concatenate([f_til, til]).astype(np.uint32))

print('[a.shape for a in all_eris]', [a.shape for a in all_eris])
print('[a.shape for a in all_indices]', [a.shape for a in all_indices])

total_diff_JK = 0
for zip_counter, (eri, idx) in enumerate(zip(all_eris, all_indices)):
for zip_counter, (eri, idx, til) in enumerate(zip(all_eris, all_indices, all_tiles)):
num_shells, shell_size = eri.shape # save original tensor shape

remainder = (eri.shape[0]) % (nbatch)

# pad tensors; unused for nipu==batches==1
if remainder != 0:
eri = jnp.pad(eri, ((0, nbatch-remainder), (0, 0)))
idx = jnp.pad(idx, ((0, nbatch-remainder), (0, 0)))

nonzero_distinct_ERI = eri.reshape(nbatch, -1)
idx = np.pad(idx, ((0, nbatch-remainder), (0, 0)))
til = np.pad(til, ((0, nbatch-remainder), (0, 0))) # zero-padding maps padded rows to tile id 0, which may overload tile 0

def gen_shell_idx(idx_sh, dx, x0):
# Unpack values
_di, _dj, _dk, _dl = dx
_i0, _j0, _k0, _l0 = x0

# Compute the indices
ind_i = (idx_sh ) % _di + _i0
ind_j = (idx_sh // (_di) ) % _dj + _j0
ind_k = (idx_sh // (_di*_dj) ) % _dk + _k0
ind_l = (idx_sh // (_di*_dj*_dk)) % _dl + _l0

# Update the array with the computed indices
# return jnp.stack([ind_i.reshape(-1), ind_j.reshape(-1), ind_k.reshape(-1), ind_l.reshape(-1)], axis=1)
return ind_i.reshape(-1, 1), ind_j.reshape(-1, 1), ind_k.reshape(-1, 1), ind_l.reshape(-1, 1)
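# Worked example of the mixed-radix decode above (all numbers illustrative):
# with dx=(2, 3, 1, 1) and x0=(5, 0, 7, 2), flat offset 4 decodes as
#   ind_i = 4 % 2      + 5 = 5
#   ind_j = 4 // 2 % 3 + 0 = 2
#   ind_k = 4 // 6 % 1 + 7 = 7
#   ind_l = 4 // 6 % 1 + 2 = 2
# so `i` varies fastest within a shell block of _di*_dj*_dk*_dl entries.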

nonzero_distinct_ERI = eri.reshape(nbatch, -1) # already mapped by now
print('len(til)', len(til))
print('idx.shape', idx.shape)

nonzero_indices = idx.reshape(nbatch, -1, 4)
# nonzero_indices = tile_put_sharded(idx, til) # shard indices to match tiles of eri shells; should be .reshape(nbatch, -1, 4) but is (-1, 4)
print('nonzero_distinct_ERI.shape', nonzero_distinct_ERI.shape)
print('nonzero_indices.shape', nonzero_indices.shape)

sh_til = np.repeat(til, repeats=shell_size, axis=-1).reshape(nbatch, -1) # repeat tile ids to keep entire shell on the same tile
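# e.g. (illustrative) til=[[3], [7]] with shell_size=2 repeats to [[3, 3], [7, 7]],
# so all `shell_size` ERI values of a shell share one tile id after the reshape.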
print('sh_til.shape', sh_til.shape)

# @partial(jax.jit, backend="ipu")
# def gen_shell_idx(idx_sh, dx, x0):
# # Unpack values
# _di, _dj, _dk, _dl = [_dx for _dx in dx]
# _i0, _j0, _k0, _l0 = [_x0 for _x0 in x0]

# # Compute the indices

# # ind_i = (idx_sh ) % _di + _i0
# ind_i = tile_copy(idx_sh)
# ind_i = tile_map(rem_inplace_p, ind_i, _di)
# ind_i = tile_map(add_inplace_p, ind_i, _i0)

# # ind_j = (idx_sh // (_di) ) % _dj + _j0
# tmp = _di # reuse _di
# ind_j = tile_copy(idx_sh)
# ind_j = tile_map(div_inplace_p, ind_j, tmp)
# ind_j = tile_map(rem_inplace_p, ind_j, _dj)
# ind_j = tile_map(add_inplace_p, ind_j, _j0)

# # ind_k = (idx_sh // (_di*_dj) ) % _dk + _k0
# tmp = tile_map(mul_inplace_p, tmp, _dj) # reuse _di
# ind_k = tile_copy(idx_sh)
# ind_k = tile_map(div_inplace_p, ind_k, tmp)
# ind_k = tile_map(rem_inplace_p, ind_k, _dk)
# ind_k = tile_map(add_inplace_p, ind_k, _k0)

# # ind_l = (idx_sh // (_di*_dj*_dk)) % _dl + _l0
# tmp = tile_map(mul_inplace_p, tmp, _dk) # reuse _di
# ind_l = tile_copy(idx_sh)
# ind_l = tile_map(div_inplace_p, ind_l, tmp)
# ind_l = tile_map(rem_inplace_p, ind_l, _dl)
# ind_l = tile_map(add_inplace_p, ind_l, _l0)

# # Update the array with the computed indices
# # return jnp.concatenate([ind_i, ind_j, ind_k, ind_l], axis=1)
# return ind_i, ind_j, ind_k, ind_l
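# The commented-out variant above is read as the on-IPU direction of this WIP:
# the same mixed-radix decode, but built from tessellate_ipu inplace primitives
# (rem/div/mul/add) under tile_map, so each shell's indices are generated on the
# tile holding its ERI values (per `til`) rather than as host-side jnp arithmetic.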



dm = dm.reshape(-1)
diff_JK = jnp.zeros(dm.shape)
@@ -311,18 +383,18 @@ def foreach_batch(i, vals):
_di, _dj, _dk, _dl = [(ao_loc[z+1] - ao_loc[z]).reshape(-1, 1) for z in [_i, _j, _k, _l]]
_i0, _j0, _k0, _l0 = [ao_loc[z].reshape(-1, 1) for z in [_i, _j, _k, _l]]

def gen_shell_idx(idx_sh):
# Compute the indices
ind_i = (idx_sh ) % _di + _i0
ind_j = (idx_sh // (_di) ) % _dj + _j0
ind_k = (idx_sh // (_di*_dj) ) % _dk + _k0
ind_l = (idx_sh // (_di*_dj*_dk)) % _dl + _l0

# Update the array with the computed indices
return jnp.stack([ind_i.reshape(-1), ind_j.reshape(-1), ind_k.reshape(-1), ind_l.reshape(-1)], axis=1)
print('_i.shape', _i.shape)
print('_di.shape', _di.shape)
print('_i0.shape', _i0.shape)

eris = nonzero_distinct_ERI[i].reshape(-1)
indices = gen_shell_idx(jnp.arange((eris.shape[0])).reshape(-1, shell_size))
idx_sh = jnp.arange((eris.shape[0])).reshape(-1, shell_size)

# indices = gen_shell_idx(jnp.arange((eris.shape[0])).reshape(-1, shell_size), (_di, _dj, _dk, _dl), (_i0, _j0, _k0, _l0))
# indices = jnp.concatenate(gen_shell_idx(idx_sh, (_di, _dj, _dk, _dl), (_i0, _j0, _k0, _l0)), axis=1)
indices = jnp.concatenate(gen_shell_idx(idx_sh, (_di, _dj, _dk, _dl), (_i0, _j0, _k0, _l0)), axis=1)
print('eris.shape', eris.shape)
print('indices.shape', indices.shape)

# compute repetitions caused by 8x symmetry when computing from the distinct_ERI form and scale accordingly
drep = num_repetitions_fast_4d(indices[:, 0], indices[:, 1], indices[:, 2], indices[:, 3], xnp=jnp, dtype=jnp.uint32)
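# num_repetitions_fast_4d is defined outside this diff; assumed to count the up-to-8
# symmetry-equivalent entries each distinct (i,j,k,l) stands for, along the lines of
#   drep = (1 + (i != j)) * (1 + (k != l)) * (1 + ((i, j) != (k, l)))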
@@ -343,7 +415,7 @@ def gen_shell_idx(idx_sh):

input_N = tile_put_replicated(jnp.array(N, dtype=jnp.uint32), input_tiles)
input_start = tile_put_replicated(jnp.array(0, dtype=jnp.uint32), input_tiles)
input_stop = tile_put_replicated(jnp.array(si.shape[1], dtype=jnp.uint32), input_tiles)
input_stop = tile_put_replicated(jnp.array(input_size//total_threads+1, dtype=jnp.uint32), input_tiles)
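# input_size//total_threads+1 is read as a per-thread ceiling: each thread now scans at
# most that many entries, rather than relying on si's (possibly padded) second dimension.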

si = tile_put_sharded(si, input_tiles)
sj = tile_put_sharded(sj, input_tiles)
@@ -359,7 +431,7 @@ def foreach_symmetry(sym, vals):
if backend == "cpu": dm_indices = cpu_ijkl(indices, sym+is_K_matrix*8, indices_func)
else: dm_indices = tile_map(compute_indices, si, sj, sk, sl, tile_put_replicated(jnp.array(sym+is_K_matrix*8, dtype=jnp.uint32), input_tiles) , input_N, input_start, input_stop).array.reshape(-1)[:input_size]
dm_values = jnp.take(dm, dm_indices, axis=0)

dm_values = dm_values.at[:].mul( eris ) # elementwise product; the variable is re-used for an in-place update

if backend == "cpu": ss_indices = cpu_ijkl(indices, sym+8+is_K_matrix*8, indices_func)
@@ -372,9 +444,10 @@ def foreach_symmetry(sym, vals):

return (diff_JK, nonzero_indices, ao_loc)

batches = nonzero_indices.shape[0] # before pmap, tensor had shape (nipus, batches, -1) so [0]=batches after pmap
# batches = nonzero_indices.shape[0] # before pmap, tensor had shape (nipus, batches, -1) so [0]=batches after pmap
# batches = len(nonzero_indices)

diff_JK, _, _ = jax.lax.fori_loop(0, batches, foreach_batch, (diff_JK, nonzero_indices, ao_loc))
diff_JK, _, _ = jax.lax.fori_loop(0, nbatch, foreach_batch, (diff_JK, nonzero_indices, ao_loc))
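# nonzero_indices is reshaped to (nbatch, -1, 4) above, so the trip count equals nbatch;
# using nbatch directly keeps the bound valid if nonzero_indices later becomes the
# tile-sharded (-1, 4) array sketched in the commented-out tile_put_sharded line.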

total_diff_JK += diff_JK

@@ -403,6 +476,7 @@ def foreach_symmetry(sym, vals):

start = time.time()

# mol = pyscf.gto.Mole(atom="".join(f"C 0 {1.54*j} {1.54*i};" for i in range(natm) for j in range(natm)), basis=args.basis)
#mol = pyscf.gto.Mole(atom="".join(f"C 0 {1.54*j} {1.54*i};" for i in range(natm) for j in range(natm))) # sto-3g by default
# mol = pyscf.gto.Mole(atom="".join(f"C 0 {1.54*j} {1.54*i};" for i in range(1) for j in range(2)), basis="sto3g")
#mol = pyscf.gto.Mole(atom="".join(f"C 0 {1.54*j} {1.54*i};" for i in range(natm) for j in range(natm)), basis="sto3g")
