Skip to content

Commit 4a7f4d5

Browse files
authored
Generate renamed to "Run Models" & Display running model name
1 parent 42b677c commit 4a7f4d5

1 file changed

Lines changed: 59 additions & 49 deletions

File tree

Vertical View - app.py

Lines changed: 59 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@
2323
div[data-testid="stFormSubmitButton"] + div[data-testid="stVerticalBlock"] div[data-testid="stVerticalBlock"] div[data-testid="stHorizontalBlock"] {
2424
align-items: flex-end;
2525
}
26+
/* Style for the "Run Models" button */
27+
button[data-testid="stButton-primary"] {
28+
background-color: #FF0000 !important; /* Red background */
29+
color: white !important; /* White text */
30+
border-radius: 8px !important; /* Rounded corners */
31+
padding: 10px 20px !important; /* Adjust padding as needed */
32+
font-size: 16px !important; /* Adjust font size as needed */
33+
}
2634
</style>
2735
""", unsafe_allow_html=True)
2836

@@ -68,64 +76,66 @@ def get_models():
6876
st.session_state.model_count -= 1
6977
st.rerun()
7078

71-
run = st.button("Generate")
79+
run = st.button("Run Models", type="primary")
7280

7381
# Main display area
7482
st.title("Running LLMs in parallel")
7583

7684
if run and prompt.strip():
7785
model_inputs = [model for model in st.session_state.selected_models if model]
78-
responses = []
79-
86+
8087
if not model_inputs:
8188
st.warning("Please select at least one model to generate a response.")
8289
else:
83-
for model in model_inputs:
84-
start = time.time()
85-
try:
86-
response = requests.post(
87-
"http://localhost:11434/api/generate",
88-
json={"model": model, "prompt": prompt, "stream": False},
89-
).json()
90+
cols = st.columns(len(model_inputs))
91+
placeholders = [col.empty() for col in cols] # Create placeholders for each model's output
92+
93+
for i, model in enumerate(model_inputs):
94+
with placeholders[i].container():
95+
model_color = "blue" if i % 2 == 0 else "red"
96+
st.markdown(
97+
f"<h3 style='color:{model_color};'>{model}</h3>",
98+
unsafe_allow_html=True
99+
)
100+
101+
# Use st.spinner for the loading indicator
102+
with st.spinner(f"Running {model}..."):
103+
start = time.time()
104+
try:
105+
response = requests.post(
106+
"http://localhost:11434/api/generate",
107+
json={"model": model, "prompt": prompt, "stream": False},
108+
).json()
90109

91-
duration = round(time.time() - start, 2)
92-
content = response.get("response", "").strip()
93-
eval_count = response.get("eval_count", len(content.split()))
94-
eval_rate = response.get("eval_rate", round(eval_count / duration, 2))
110+
duration = round(time.time() - start, 2)
111+
content = response.get("response", "").strip()
112+
eval_count = response.get("eval_count", len(content.split()))
113+
eval_rate = response.get("eval_rate", round(eval_count / duration, 2))
95114

96-
responses.append({
97-
"model": model,
98-
"duration": duration,
99-
"eval_count": eval_count,
100-
"eval_rate": eval_rate,
101-
"response": content
102-
})
103-
except Exception as e:
104-
responses.append({
105-
"model": model,
106-
"duration": 0,
107-
"eval_count": 0,
108-
"eval_rate": 0,
109-
"response": f"Error: {e}"
110-
})
115+
# Clear the spinner and display the actual response
116+
placeholders[i].empty() # Clear the placeholder content including the spinner
117+
with placeholders[i].container(): # Redraw content
118+
st.markdown(
119+
f"<h3 style='color:{model_color};'>{model}</h3>",
120+
unsafe_allow_html=True
121+
)
122+
st.markdown(
123+
f"""
124+
<div style="background-color:#e6f0ff; padding:10px; border-radius:8px; margin-bottom:10px;">
125+
<b>Duration</b>: <span style="color:#3366cc;">{duration} secs</span><br>
126+
<b>Eval count</b>: <span style="color:green;">{eval_count} tokens</span><br>
127+
<b>Eval rate</b>: <span style="color:green;">{eval_rate} tokens/s</span>
128+
</div>
129+
""",
130+
unsafe_allow_html=True
131+
)
132+
st.write(content)
111133

112-
if responses:
113-
cols = st.columns(len(responses))
114-
for i, res in enumerate(responses):
115-
with cols[i]:
116-
model_color = "blue" if i % 2 == 0 else "red"
117-
st.markdown(
118-
f"<h3 style='color:{model_color};'>{res['model']}</h3>",
119-
unsafe_allow_html=True
120-
)
121-
st.markdown(
122-
f"""
123-
<div style="background-color:#e6f0ff; padding:10px; border-radius:8px; margin-bottom:10px;">
124-
<b>Duration</b>: <span style="color:#3366cc;">{res['duration']} secs</span><br>
125-
<b>Eval count</b>: <span style="color:green;">{res['eval_count']} tokens</span><br>
126-
<b>Eval rate</b>: <span style="color:green;">{res['eval_rate']} tokens/s</span>
127-
</div>
128-
""",
129-
unsafe_allow_html=True
130-
)
131-
st.write(res["response"])
134+
except Exception as e:
135+
placeholders[i].empty() # Clear the placeholder content including the spinner
136+
with placeholders[i].container(): # Redraw content
137+
st.markdown(
138+
f"<h3 style='color:{model_color};'>{model}</h3>",
139+
unsafe_allow_html=True
140+
)
141+
st.error(f"Error: {e}")

0 commit comments

Comments
 (0)