Skip to content

Commit

Permalink
Add multi_add interface
Browse files Browse the repository at this point in the history
  • Loading branch information
LuoYuanke committed Feb 27, 2018
1 parent 5faca3f commit c7c20bf
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
9 changes: 7 additions & 2 deletions python/ideep4py/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ def check_ndim(inputs, supported_ndim=(2, 4)):


def check_type(inputs):
if isinstance(inputs[0], numpy.ndarray) or \
isinstance(inputs[0], mdarray):
if isinstance(inputs[0], (numpy.ndarray, mdarray)):
_should_use_ideep = True

for x in inputs:
Expand Down Expand Up @@ -178,6 +177,12 @@ def tanh(x):
return y


def multi_add(xs):
xs = tuple([array(x) for x in xs])
y = basic_acc_sum(xs)
return y


def convolution2DParam(out_dims, dy, dx, sy, sx, ph, pw, pd, pr):
cp = conv2DParam()
cp.out_dims = intVector()
Expand Down
7 changes: 7 additions & 0 deletions python/ideep4py/tests/mm/test_acc_sum_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,10 @@ def test_sum(self):
mx = ideep4py.basic_acc_sum((mx1, mx2, mx3, mx4))
# mx = numpy.asarray(mx)
numpy.testing.assert_allclose(mx, x, **self.check_options)

def test_multi_add(self):
mx1 = ideep4py.mdarray(self.x1)
mx2 = ideep4py.mdarray(self.x2)
x = self.x1 + self.x2 + self.x3 + self.x4
mx = ideep4py.multi_add((mx1, mx2, self.x3, self.x4))
numpy.testing.assert_allclose(mx, x, **self.check_options)

0 comments on commit c7c20bf

Please sign in to comment.