Skip to content

Commit

Permalink
move function definition to cpp file (#4522)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucafedeli88 authored Dec 15, 2023
1 parent 7233e4b commit 50e04bf
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 47 deletions.
47 changes: 1 addition & 46 deletions Source/AcceleratorLattice/LatticeElementFinder.H
Original file line number Diff line number Diff line change
Expand Up @@ -89,52 +89,7 @@ struct LatticeElementFinder
*/
void setup_lattice_indices (amrex::Gpu::DeviceVector<amrex::ParticleReal> const & zs,
amrex::Gpu::DeviceVector<amrex::ParticleReal> const & ze,
amrex::Gpu::DeviceVector<int> & indices)
{

using namespace amrex::literals;

const auto nelements = static_cast<int>(zs.size());
amrex::ParticleReal const * zs_arr = zs.data();
amrex::ParticleReal const * ze_arr = ze.data();
int * indices_arr = indices.data();

amrex::Real const zmin = m_zmin;
amrex::Real const dz = m_dz;

amrex::ParticleReal const gamma_boost = m_gamma_boost;
amrex::ParticleReal const uz_boost = m_uz_boost;
amrex::Real const time = m_time;

amrex::ParallelFor( m_nz,
[=] AMREX_GPU_DEVICE (int iz) {

// Get the location of the grid node
amrex::Real z_node = zmin + iz*dz;

if (gamma_boost > 1._prt) {
// Transform to lab frame
z_node = gamma_boost*z_node + uz_boost*time;
}

// Find the index to the element that is closest to the grid cell.
// For now, this assumes that there is no overlap among elements of the same type.
for (int ie = 0 ; ie < nelements ; ie++) {
// Find the mid points between element ie and the ones before and after it.
// The first and last element need special handling.
const amrex::ParticleReal zcenter_left = (ie == 0)?
(std::numeric_limits<amrex::ParticleReal>::lowest()) : (0.5_prt*(ze_arr[ie-1] + zs_arr[ie]));
const amrex::ParticleReal zcenter_right = (ie < nelements - 1)?
(0.5_prt*(ze_arr[ie] + zs_arr[ie+1])) : (std::numeric_limits<amrex::ParticleReal>::max());
if (zcenter_left <= z_node && z_node < zcenter_right) {
indices_arr[iz] = ie;
}

}
}
);
}

amrex::Gpu::DeviceVector<int> & indices);
};

/**
Expand Down
48 changes: 47 additions & 1 deletion Source/AcceleratorLattice/LatticeElementFinder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ LatticeElementFinder::GetFinderDeviceInstance (WarpXParIter const& a_pti, int co
return result;
}


void
LatticeElementFinderDevice::InitLatticeElementFinderDevice (WarpXParIter const& a_pti, int const a_offset,
AcceleratorLattice const& accelerator_lattice,
Expand Down Expand Up @@ -121,3 +120,50 @@ LatticeElementFinderDevice::InitLatticeElementFinderDevice (WarpXParIter const&
}

}

void
LatticeElementFinder::setup_lattice_indices (amrex::Gpu::DeviceVector<amrex::ParticleReal> const & zs,
amrex::Gpu::DeviceVector<amrex::ParticleReal> const & ze,
amrex::Gpu::DeviceVector<int> & indices)
{

using namespace amrex::literals;

const auto nelements = static_cast<int>(zs.size());
amrex::ParticleReal const * zs_arr = zs.data();
amrex::ParticleReal const * ze_arr = ze.data();
int * indices_arr = indices.data();

amrex::Real const zmin = m_zmin;
amrex::Real const dz = m_dz;

amrex::ParticleReal const gamma_boost = m_gamma_boost;
amrex::ParticleReal const uz_boost = m_uz_boost;
amrex::Real const time = m_time;

amrex::ParallelFor( m_nz,
[=] AMREX_GPU_DEVICE (int iz) {

// Get the location of the grid node
amrex::Real z_node = zmin + iz*dz;

if (gamma_boost > 1._prt) {
// Transform to lab frame
z_node = gamma_boost*z_node + uz_boost*time;
}

// Find the index to the element that is closest to the grid cell.
// For now, this assumes that there is no overlap among elements of the same type.
for (int ie = 0 ; ie < nelements ; ie++) {
// Find the mid points between element ie and the ones before and after it.
// The first and last element need special handling.
const amrex::ParticleReal zcenter_left = (ie == 0)?
(std::numeric_limits<amrex::ParticleReal>::lowest()) : (0.5_prt*(ze_arr[ie-1] + zs_arr[ie]));
const amrex::ParticleReal zcenter_right = (ie < nelements - 1)?
(0.5_prt*(ze_arr[ie] + zs_arr[ie+1])) : (std::numeric_limits<amrex::ParticleReal>::max());
if (zcenter_left <= z_node && z_node < zcenter_right) {
indices_arr[iz] = ie;
}
}
});
}

0 comments on commit 50e04bf

Please sign in to comment.