diff --git a/doccano_mini/app.py b/doccano_mini/app.py index 299e947..6750970 100644 --- a/doccano_mini/app.py +++ b/doccano_mini/app.py @@ -1,3 +1,5 @@ +import os + import streamlit as st from langchain.chains import LLMChain from langchain.llms import OpenAI @@ -47,6 +49,7 @@ def task_classification(task: TaskType): prompt.prefix = instruction st.header("Test") + api_key = st.text_input("Enter API key", value=os.environ.get("OPENAI_API_KEY", "")) col1, col2 = st.columns([3, 1]) text = col1.text_area(label="Please enter your text.", value="", height=300) @@ -59,7 +62,7 @@ def task_classification(task: TaskType): st.markdown(f"```\n{prompt.format(input=text)}\n```") if st.button("Predict"): - llm = OpenAI(model_name=model_name, temperature=temperature, top_p=top_p) # type:ignore + llm = OpenAI(model_name=model_name, temperature=temperature, top_p=top_p, openai_api_key=api_key) # type:ignore chain = LLMChain(llm=llm, prompt=prompt) response = chain.run(text) label = response.split(":")[1]