Skip to content

Commit

Permalink
Override eval method for a branch
Browse files Browse the repository at this point in the history
  • Loading branch information
rauletorresc committed Feb 1, 2025
1 parent e53ff3d commit 2c5aef0
Showing 1 changed file with 51 additions and 10 deletions.
61 changes: 51 additions & 10 deletions frontend/catalyst/from_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,19 @@
This submodule defines a utility for converting plxpr into Catalyst jaxpr.
"""
# pylint: disable=protected-access
from copy import copy
from functools import partial
from typing import Callable
from typing import Callable, Sequence

import jax
import jax.core
import pennylane as qml
from jax.extend.linear_util import wrap_init
from pennylane.capture import PlxprInterpreter, disable, enable, enabled, qnode_prim
from pennylane.capture.primitives import cond_prim
from pennylane.capture.primitives import (
AbstractMeasurement,
AbstractOperator,
cond_prim,
)

from catalyst.device import (
extract_backend_info,
Expand Down Expand Up @@ -320,6 +323,50 @@ def cleanup(self):
self.qreg = qinsert_p.bind(self.qreg, orig_wire, wire)
self.stateref = None

def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list:
"""Evaluate a jaxpr.
Args:
jaxpr (jax.core.Jaxpr): the jaxpr to evaluate
consts (list[TensorLike]): the constant variables for the jaxpr
*args (tuple[TensorLike]): The arguments for the jaxpr.
Returns:
list[TensorLike]: the results of the execution.
"""
self._env = {}
self.setup()

for const, constvar in zip(consts, jaxpr.constvars, strict=True):
self._env[constvar] = const

for eqn in jaxpr.eqns:

custom_handler = self._primitive_registrations.get(eqn.primitive, None)
if custom_handler:
invals = [self.read(invar) for invar in eqn.invars]
outvals = custom_handler(self, *invals, **eqn.params)

Check warning on line 349 in frontend/catalyst/from_plxpr.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/from_plxpr.py#L348-L349

Added lines #L348 - L349 were not covered by tests
elif isinstance(eqn.outvars[0].aval, AbstractOperator):
outvals = self.interpret_operation_eqn(eqn)
elif isinstance(eqn.outvars[0].aval, AbstractMeasurement):
outvals = self.interpret_measurement_eqn(eqn)

Check warning on line 353 in frontend/catalyst/from_plxpr.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/from_plxpr.py#L353

Added line #L353 was not covered by tests
else:
invals = [self.read(invar) for invar in eqn.invars]
outvals = eqn.primitive.bind(*invals, **eqn.params)

Check warning on line 356 in frontend/catalyst/from_plxpr.py

View check run for this annotation

Codecov / codecov/patch

frontend/catalyst/from_plxpr.py#L355-L356

Added lines #L355 - L356 were not covered by tests

if not eqn.primitive.multiple_results:
outvals = [outvals]
for outvar, outval in zip(eqn.outvars, outvals, strict=True):
self._env[outvar] = outval

outvals = [qalloc_p.bind(len(self._device.wires))]

self.cleanup()
self._env = {}

return outvals


@QFuncPlxprInterpreter.register_primitive(qml.QubitUnitary._primitive)
def _(self, *invals, n_wires):
Expand Down Expand Up @@ -362,15 +409,9 @@ def convert_branch_from_plxpr_to_jaxpr():
# Get the input qreg var.
in_qreg_var = converted_jaxpr_branch.constvars[0]

# Create an output qreg var as a copy of the input one
out_qreg_var = copy(in_qreg_var)

# Overwrite the input vars with the input qreg var
converted_jaxpr_branch = converted_jaxpr_branch.replace(invars=[in_qreg_var])

# Overwrite the output vars with the output qreg var
converted_jaxpr_branch = converted_jaxpr_branch.replace(outvars=[out_qreg_var])

# Remove the qreg var from the constants
constvars = converted_jaxpr_branch.constvars
converted_jaxpr_branch = converted_jaxpr_branch.replace(constvars=constvars[1:])
Expand All @@ -381,7 +422,7 @@ def convert_branch_from_plxpr_to_jaxpr():
if num_eqns > 1:
last_eqn = num_eqns - 1
converted_jaxpr_branch.eqns[last_eqn] = converted_jaxpr_branch.eqns[last_eqn].replace(
outvars=[out_qreg_var]
outvars=converted_jaxpr_branch.outvars
)

return converted_jaxpr_branch
Expand Down

0 comments on commit 2c5aef0

Please sign in to comment.