Skip to content

Commit

Permalink
[𝘀𝗽𝗿] initial version
Browse files Browse the repository at this point in the history
Created using spr 1.3.4
  • Loading branch information
boomanaiden154 committed Jan 16, 2025
1 parent 3ddfad9 commit 2f7027b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 16 deletions.
25 changes: 10 additions & 15 deletions compiler_opt/es/blackbox_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,18 @@ def get_results(
def set_baseline(self) -> None:
raise NotImplementedError()

@abc.abstractmethod
def get_rewards(
self, results: List[concurrent.futures.Future]) -> List[Optional[float]]:
raise NotImplementedError()
rewards = [None] * len(results)

for i in range(len(results)):
if not results[i].exception():
rewards[i] = results[i].result()
else:
logging.info('Error retrieving result from future: %s',
str(results[i].exception()))

return rewards


@gin.configurable
Expand Down Expand Up @@ -95,16 +103,3 @@ def get_results(

def set_baseline(self) -> None:
pass

def get_rewards(
self, results: List[concurrent.futures.Future]) -> List[Optional[float]]:
rewards = [None] * len(results)

for i in range(len(results)):
if not results[i].exception():
rewards[i] = results[i].result()
else:
logging.info('Error retrieving result from future: %s',
str(results[i].exception()))

return rewards
2 changes: 1 addition & 1 deletion compiler_opt/es/blackbox_evaluator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_sampling_get_results(self):
self.assertSequenceAlmostEqual([result.result() for result in results],
[1.0, 1.0, 1.0])

def test_sampling_get_rewards(self):
def test_get_rewards(self):
f1 = concurrent.futures.Future()
f1.set_exception(None)
f2 = concurrent.futures.Future()
Expand Down

0 comments on commit 2f7027b

Please sign in to comment.