Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Nov 23, 2024
1 parent 7e4575c commit bbb922c
Show file tree
Hide file tree
Showing 25 changed files with 19 additions and 1,373 deletions.
13 changes: 7 additions & 6 deletions brainpy/_src/dnn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jax
import jax.numpy as jnp
import numpy as np
import braintaichi as bti

from brainpy import math as bm
from brainpy._src import connect, initialize as init
Expand Down Expand Up @@ -273,7 +274,7 @@ def _dense_on_post(
out_w[i, j] = old_w[i, j]


dense_on_post_prim = bm.XLACustomOp(cpu_kernel=_dense_on_post, gpu_kernel=_dense_on_post)
dense_on_post_prim = bti.XLACustomOp(cpu_kernel=_dense_on_post, gpu_kernel=_dense_on_post)


# @numba.njit(nogil=True, fastmath=True, parallel=False)
Expand Down Expand Up @@ -309,7 +310,7 @@ def _dense_on_pre(
out_w[i, j] = old_w[i, j]


dense_on_pre_prim = bm.XLACustomOp(cpu_kernel=_dense_on_pre, gpu_kernel=_dense_on_pre)
dense_on_pre_prim = bti.XLACustomOp(cpu_kernel=_dense_on_pre, gpu_kernel=_dense_on_pre)

else:
dense_on_pre_prim = None
Expand Down Expand Up @@ -735,7 +736,7 @@ def _csr_on_pre_update(
out_w[i_syn] = old_w[i_syn]


csr_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_csr_on_pre_update, gpu_kernel=_csr_on_pre_update)
csr_on_pre_update_prim = bti.XLACustomOp(cpu_kernel=_csr_on_pre_update, gpu_kernel=_csr_on_pre_update)


@ti.kernel
Expand All @@ -759,7 +760,7 @@ def _coo_on_pre_update(
out_w[i_syn] = old_w[i_syn]


coo_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_coo_on_pre_update, gpu_kernel=_coo_on_pre_update)
coo_on_pre_update_prim = bti.XLACustomOp(cpu_kernel=_coo_on_pre_update, gpu_kernel=_coo_on_pre_update)


@ti.kernel
Expand All @@ -783,7 +784,7 @@ def _coo_on_post_update(
out_w[i_syn] = old_w[i_syn]


coo_on_post_update_prim = bm.XLACustomOp(cpu_kernel=_coo_on_post_update, gpu_kernel=_coo_on_post_update)
coo_on_post_update_prim = bti.XLACustomOp(cpu_kernel=_coo_on_post_update, gpu_kernel=_coo_on_post_update)


# @numba.njit(nogil=True, fastmath=True, parallel=False)
Expand Down Expand Up @@ -824,7 +825,7 @@ def _csc_on_post_update(
out_w[i_syn] = old_w[i_syn]


csc_on_post_update_prim = bm.XLACustomOp(cpu_kernel=_csc_on_post_update, gpu_kernel=_csc_on_post_update)
csc_on_post_update_prim = bti.XLACustomOp(cpu_kernel=_csc_on_post_update, gpu_kernel=_csc_on_post_update)


else:
Expand Down
4 changes: 0 additions & 4 deletions brainpy/_src/dnn/tests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@
import brainpy as bp
import brainpy.math as bm

from brainpy._src.dependency_check import import_taichi

if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)


class TestLinear(parameterized.TestCase):
Expand Down
4 changes: 0 additions & 4 deletions brainpy/_src/dnn/tests/test_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@

import brainpy as bp
import brainpy.math as bm
from brainpy._src.dependency_check import import_taichi

if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)


class Test_Conv(parameterized.TestCase):
Expand Down
4 changes: 0 additions & 4 deletions brainpy/_src/dyn/projections/tests/test_STDP.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@

import brainpy as bp
import brainpy.math as bm
from brainpy._src.dependency_check import import_taichi

if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)

bm.set_platform('cpu')

Expand Down
4 changes: 0 additions & 4 deletions brainpy/_src/dyn/projections/tests/test_aligns.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@
import brainpy as bp
import brainpy.math as bm

from brainpy._src.dependency_check import import_taichi

if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)

neu_pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
V_initializer=bp.init.Normal(-55., 2.))
Expand Down
4 changes: 0 additions & 4 deletions brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
import brainpy as bp
import brainpy.math as bm
from brainpy._src.dynold.synapses import abstract_models
from brainpy._src.dependency_check import import_taichi

if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)


class Test_Abstract_Synapse(parameterized.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@
import brainpy as bp
import brainpy.math as bm

from brainpy._src.dependency_check import import_taichi

if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)

biological_models = [
bp.synapses.AMPA,
Expand Down
13 changes: 1 addition & 12 deletions brainpy/_src/math/event/csr_matmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,10 @@

from typing import Union, Tuple

import jax
import numpy as np
from jax import numpy as jnp
from jax.interpreters import ad
from jax.experimental.sparse import csr
from braintaichi import event_csrmm as bt_event_csrmm
from jax import numpy as jnp

from brainpy._src.dependency_check import import_taichi
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import Array
from brainpy._src.math.op_register import (XLACustomOp, register_general_batching)
from brainpy._src.math.sparse.utils import csr_to_coo
from brainpy._src.math.defaults import float_

ti = import_taichi()

__all__ = [
'csrmm',
Expand Down
3 changes: 0 additions & 3 deletions brainpy/_src/math/event/csr_matvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,10 @@
import jax
from braintaichi import event_csrmv as bt_event_csrmv

from brainpy._src.dependency_check import import_taichi

__all__ = [
'csrmv'
]

ti = import_taichi(error_if_not_found=False)


def csrmv(
Expand Down
4 changes: 0 additions & 4 deletions brainpy/_src/math/event/tests/test_event_csrmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@

import brainpy as bp
import brainpy.math as bm
from brainpy._src.dependency_check import import_taichi

if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)

import platform
force_test = False # turn on to force test on windows locally
Expand Down
2 changes: 0 additions & 2 deletions brainpy/_src/math/jitconn/event_matvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@
import jax
from braintaichi import jitc_event_mv_prob_homo, jitc_event_mv_prob_uniform, jitc_event_mv_prob_normal

from brainpy._src.dependency_check import import_taichi
from brainpy._src.math.jitconn.matvec import (mv_prob_homo,
mv_prob_uniform,
mv_prob_normal)

ti = import_taichi(error_if_not_found=False)

__all__ = [
'event_mv_prob_homo',
Expand Down
Loading

0 comments on commit bbb922c

Please sign in to comment.