|
| 1 | +import asyncio |
| 2 | +import os |
| 3 | +import shutil |
| 4 | + |
| 5 | +import pandas as pd |
| 6 | +import writer as wf |
| 7 | +import writer.ai |
| 8 | +from dotenv import load_dotenv |
| 9 | + |
| 10 | +from prompts import (generate_negative_concatenation_prompt, |
| 11 | + generate_negative_impacts_prompt, |
| 12 | + generate_negative_stock_selection_prompt, |
| 13 | + generate_positive_concatenation_prompt, |
| 14 | + generate_positive_impacts_prompt, |
| 15 | + generate_positive_sector_weighting_prompt, |
| 16 | + generate_positive_stock_selection_prompt, |
| 17 | + generate_rebalance_recommendation_prompt) |
| 18 | + |
# Load environment variables (e.g. the Writer API key) from a local .env file.
load_dotenv()

# Suppress pandas' SettingWithCopyWarning for chained assignments.
pd.options.mode.chained_assignment = None
| 22 | + |
| 23 | + |
| 24 | +def handle_file_on_change(state, payload): |
| 25 | + clear_results(state) |
| 26 | + _save_file(state, payload[0]) |
| 27 | + |
| 28 | + file_extension = state["file"]["name"].split(".")[-1].lower() |
| 29 | + |
| 30 | + match file_extension: |
| 31 | + case "xlsx" | "xls": |
| 32 | + df = _read_excel(state) |
| 33 | + case "pdf": |
| 34 | + text_data = _read_pdf(state) |
| 35 | + case _: |
| 36 | + state["processing-message"] = "Unsupported file type" |
| 37 | + return |
| 38 | + |
| 39 | + state["processing-message"] = "Analyzing positive and negative impacts..." |
| 40 | + prompt = analyze_data( |
| 41 | + df if file_extension in ["xlsx", "xls"] else text_data |
| 42 | + ) |
| 43 | + |
| 44 | + state["processing-message"] = "Generating rebalancing recommendation..." |
| 45 | + state["analysis-result"] = "" |
| 46 | + |
| 47 | + # Using Palmyra-Fin model |
| 48 | + # for chunk in writer.ai.stream_complete(prompt, {"model": "palmyra-fin-32k", "max_tokens": 2048, "temperature": 0.7}): |
| 49 | + # state["analysis-result"] += chunk |
| 50 | + |
| 51 | + # Using Palmyra X 004 model |
| 52 | + conversation = writer.ai.Conversation([{"role": "user", "content": prompt}], {"model": "palmyra-x-004", "max_tokens": 2048, "temperature": 0.7}) |
| 53 | + for chunk in conversation.stream_complete(): |
| 54 | + if chunk.get("content"): |
| 55 | + state["analysis-result"] += chunk.get("content") |
| 56 | + |
| 57 | + state["visual_block_visible"] = True |
| 58 | + state["processing-message"] = "" |
| 59 | + |
| 60 | + |
def clear_results(state):
    """Reset the analysis output, hide the visuals block, and remove any
    previously uploaded files from disk."""
    state["visual_block_visible"] = False
    state["analysis-result"] = "Your analysis will appear here."
    _delete_all_files(state)
| 65 | + |
| 66 | + |
| 67 | +def _save_file(state, file): |
| 68 | + name = file.get("name") |
| 69 | + state["file"]["name"] = name |
| 70 | + state["file"]["file_path"] = f"data/{name}" |
| 71 | + state["processing-message"] = f"File {name} saved." |
| 72 | + file_data = file.get("data") |
| 73 | + with open(f"data/{name}", "wb") as file_handle: |
| 74 | + file_handle.write(file_data) |
| 75 | + |
| 76 | + |
| 77 | +def _delete_all_files(state): |
| 78 | + directory = "data" |
| 79 | + |
| 80 | + if os.path.exists(directory): |
| 81 | + shutil.rmtree(directory) |
| 82 | + |
| 83 | + os.makedirs(directory) |
| 84 | + |
| 85 | + state["file"]["name"] = "" |
| 86 | + state["file"]["file_path"] = "" |
| 87 | + state["processing-message"] = "All files have been deleted." |
| 88 | + |
| 89 | + |
def _read_excel(state):
    """Load the uploaded spreadsheet into a DataFrame.

    Returns:
        pd.DataFrame: parsed contents of ``state["file"]["file_path"]``.
    """
    # pd.read_excel already returns a DataFrame; the original wrapped it in
    # pd.DataFrame(...) again, which only made a redundant copy.
    return pd.read_excel(state["file"]["file_path"])
| 94 | + |
| 95 | + |
def _read_pdf(state):
    """Extract plain text from every page of the uploaded PDF."""
    # Imported lazily so the dependency is only needed for PDF uploads.
    from PyPDF2 import PdfReader

    reader = PdfReader(state["file"]["file_path"])
    return "".join(page.extract_text() for page in reader.pages)
| 103 | + return text |
| 104 | + |
# Fan out one completion request per analysis prompt, then fold the answers
# into the final rebalancing prompt.
async def gather_results_and_generate_rebalance_prompt(data: str):
    """Run all positive/negative analysis prompts concurrently and build
    the final rebalance-recommendation prompt from their results.

    Args:
        data: Portfolio data rendered as a string.

    Returns:
        The prompt produced by ``generate_rebalance_recommendation_prompt``.
    """

    async def complete_async(prompt):
        # writer.ai.complete is blocking; run it in a worker thread so the
        # seven requests can overlap.
        return await asyncio.to_thread(
            writer.ai.complete,
            prompt,
            {"model": "palmyra-fin-32k", "max_tokens": 3048, "temperature": 0.7},
        )

    # Order matters: generate_rebalance_recommendation_prompt expects its
    # arguments in exactly this sequence.
    prompt_builders = (
        generate_positive_stock_selection_prompt,
        generate_positive_sector_weighting_prompt,
        generate_positive_concatenation_prompt,
        generate_positive_impacts_prompt,
        generate_negative_stock_selection_prompt,
        generate_negative_concatenation_prompt,
        generate_negative_impacts_prompt,
    )

    results = await asyncio.gather(
        *(complete_async(build(data)) for build in prompt_builders)
    )

    return generate_rebalance_recommendation_prompt(*results)
| 163 | + |
| 164 | + |
def analyze_data(data):
    """Kick off the async prompt pipeline for the uploaded data.

    Args:
        data: Either a pandas DataFrame (spreadsheet upload) or a str
            (extracted PDF text).

    Returns:
        The final rebalance-recommendation prompt string.
    """
    text = data.to_string() if isinstance(data, pd.DataFrame) else data
    return asyncio.run(gather_results_and_generate_rebalance_prompt(text))
| 174 | + |
| 175 | + |
def handle_file_download(state):
    """Write the current analysis to disk and trigger a browser download."""
    analysis_result = state["analysis-result"]
    if not analysis_result:
        return

    result_path = "data/analysis_result.txt"
    with open(result_path, "w") as file_handle:
        file_handle.write(analysis_result)

    packed = wf.pack_file(result_path, "text/plain")
    state.file_download(packed, "analysis_result.txt")
| 185 | + |
| 186 | + |
# Initial application state shared with the Writer UI. Keys are referenced
# by the handlers above (file tracking, status message, result text) and
# bound to components in the front end.
initial_state = wf.init_state(
    {
        "image-path": "static/writer_logo.png",
        "app": {"title": "Recommendations for Rebalancing Portfolio"},
        "file": {"name": "", "file_path": ""},
        "analysis-result": "Your recommendations will appear here.",
        "processing-message": "",
        "visual_block_visible": False
    }
)


# Attach the app's custom stylesheet to the UI.
initial_state.import_stylesheet("style", "/static/custom.css?")
0 commit comments