Skip to content

Commit

Permalink
cleaned test files; added argv and stdin to initial prompt, if anythi…
Browse files Browse the repository at this point in the history
…ngs is there
  • Loading branch information
stephenfreund committed Mar 20, 2024
1 parent fd1b23e commit 1555289
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 2 deletions.
15 changes: 13 additions & 2 deletions src/chatdbg/chatdbg_pdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@
import sys
import textwrap
import traceback
from io import StringIO
from io import StringIO, TextIOWrapper
from pathlib import Path
from pprint import pprint

import IPython
import llm_utils
from traitlets import TraitError

from chatdbg.ipdb_util.capture import CaptureInput

from .assistant.assistant import Assistant
from .ipdb_util.config import Chat
from .ipdb_util.logging import ChatDBGLog, CopyingTextIOWrapper
Expand Down Expand Up @@ -84,11 +86,13 @@ def __init__(self, *args, **kwargs):
self._assistant = None
self._history = []
self._error_specific_prompt = ""

global chatdbg_config
if chatdbg_config == None:
chatdbg_config = Chat()

sys.stdin = CaptureInput(sys.stdin)

# Only use flow when we are in jupyter or using stdin in ipython. In both
# cases, there will be no python file at the start of argv after the
# ipython commands.
Expand Down Expand Up @@ -526,6 +530,11 @@ def _build_prompt(self, arg, conversing):
if not conversing:
stack_dump = f"The program has this stack trace:\n```\n{self.format_stack_trace()}\n```\n\n"
prompt = "\n" + stack_dump + self._error_specific_prompt
if len(sys.argv) > 1:
prompt += f"\nThese were the command line options:\n```\n{' '.join(sys.argv)}\n```\n"
input = sys.stdin.get_captured_input()
if len(input) > 0:
prompt += f"\nThis was the program's input :\n```\n{input}```\n"

if len(self._history) > 0:
hist = textwrap.indent(self._capture_onecmd("hist"), "")
Expand Down Expand Up @@ -578,6 +587,8 @@ def do_mark(self, arg):
else:
self._log.add_mark(arg)



def do_config(self, arg):
args = arg.split()
if len(args) == 0:
Expand Down
31 changes: 31 additions & 0 deletions src/chatdbg/ipdb_util/capture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from io import StringIO, TextIOWrapper

class CaptureInput:
    """File-like wrapper that records every piece of text read through it.

    Intended as a drop-in replacement for ``sys.stdin``: reads are forwarded
    to the wrapped stream and a copy of the returned text is appended to an
    in-memory buffer, retrievable with :meth:`get_captured_input`.

    Attributes:
        original_input: the text stream that reads are forwarded to.
        capture_buffer: ``StringIO`` accumulating everything read so far.
        original_readline: the raw layer's original ``readline`` when the
            wrapped stream exposes ``buffer.raw``; otherwise ``None``.
    """

    def __init__(self, input_stream):
        """Wrap ``input_stream`` and start capturing what is read from it.

        Args:
            input_stream: a readable text stream.  When it exposes a binary
                ``buffer`` it is re-wrapped as UTF-8 with ``newline=''`` so
                line endings are captured untranslated; streams without a
                ``buffer`` (e.g. ``StringIO``) are used as they are.
        """
        buffer = getattr(input_stream, "buffer", None)
        if buffer is not None:
            input_stream = TextIOWrapper(buffer, encoding="utf-8", newline="")

        self.original_input = input_stream
        self.capture_buffer = StringIO()

        # Best effort: also record reads made directly on the raw layer.
        # Not every stream has a raw layer, and not every raw object allows
        # attribute assignment, so this hook is optional.
        raw = getattr(buffer, "raw", None)
        self.original_readline = getattr(raw, "readline", None)
        if self.original_readline is not None:

            def custom_readline(*args, **kwargs):
                input_data = self.original_readline(*args, **kwargs)
                if input_data:
                    # Never let a partial multi-byte sequence break the read.
                    self.capture_buffer.write(
                        input_data.decode("utf-8", errors="replace")
                    )
                return input_data

            try:
                raw.readline = custom_readline
            except (AttributeError, TypeError):
                pass

    def __getattr__(self, name):
        """Delegate anything not defined here (isatty, fileno, ...) to the stream."""
        # Only called when normal lookup fails; guard against infinite
        # recursion if the instance has not been initialised yet.
        if name == "original_input":
            raise AttributeError(name)
        return getattr(self.original_input, name)

    def __iter__(self):
        """Iterate over the remaining lines, capturing each one."""
        return self

    def __next__(self):
        line = self.readline()
        if not line:
            raise StopIteration
        return line

    def _record(self, input_data):
        """Append ``input_data`` to the capture buffer and return it unchanged."""
        if input_data:
            self.capture_buffer.write(input_data)
        return input_data

    def readline(self, *args, **kwargs):
        """Read one line from the wrapped stream and record it."""
        return self._record(self.original_input.readline(*args, **kwargs))

    def read(self, *args, **kwargs):
        """Read from the wrapped stream and record the returned text."""
        return self._record(self.original_input.read(*args, **kwargs))

    def readlines(self, *args, **kwargs):
        """Read the remaining lines from the wrapped stream and record them."""
        lines = self.original_input.readlines(*args, **kwargs)
        self.capture_buffer.write("".join(lines))
        return lines

    def get_captured_input(self):
        """Return all text read through this wrapper so far."""
        return self.capture_buffer.getvalue()
21 changes: 21 additions & 0 deletions test/python/bootstrap2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from datascience import *
from ds101 import *

def make_marble_sample():
    """Read 'marble-sample.csv' and return its 'color' column."""
    marbles = Table().read_table('marble-sample.csv')
    colors = marbles.column('color')
    return colors

def proportion_blue(sample):
    """Return the fraction of entries in ``sample`` that equal 'B'."""
    blue_count = np.count_nonzero(sample == 'B')
    total = len(sample)
    return blue_count / total

def resampled_stats(observed_marbles, num_trials):
    """Bootstrap ``proportion_blue`` over ``observed_marbles``.

    Calls ``bootstrap_statistic`` with ``num_trials`` trials, checks that it
    produced exactly that many statistics, and returns them.
    """
    results = bootstrap_statistic(observed_marbles, proportion_blue, num_trials)
    assert len(results) == num_trials
    return results

# Script body: load the marble colors, then compute 5 bootstrap statistics.
observed_marbles = make_marble_sample()
stats = resampled_stats(observed_marbles, 5)

# Check that the mean of the bootstrap statistics is close to 0.7.
# NOTE(review): the resampling looks unseeded, so this assertion may fail from
# run to run -- presumably intentional for a debugger test program; confirm.
assert np.isclose(np.mean(stats), 0.7)
34 changes: 34 additions & 0 deletions test/python/ds101.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import numpy as np
from datascience import *

# fake library to hide identities...

def bootstrap_statistic(observed_sample, compute_statistic, num_trials):
    """
    Draw num_trials resamples from observed_sample and return an array
    holding compute_statistic evaluated on each resample.

    * observed_sample: the initial sample; must be a NumPy array.
    * compute_statistic: a function that takes a sample as an array
          and returns the statistic for that sample.
    * num_trials: how many bootstrap resamples to create.

    Raises ValueError when observed_sample is not a NumPy array.
    """

    # Reject anything that is not an array before doing any work.
    if not isinstance(observed_sample, np.ndarray):
        type_name = str(type(observed_sample).__name__)
        raise ValueError('The first parameter to bootstrap_statistic must be a sample represented as an array, not a value of type ' + type_name)

    collected = make_array()

    for _ in np.arange(0, num_trials):
        # Bootstrapping always samples with replacement, keeping the
        # resample the same size as the observed sample.
        resample = np.random.choice(observed_sample, len(observed_sample))
        collected = np.append(collected, compute_statistic(resample))

    return collected

0 comments on commit 1555289

Please sign in to comment.