import os
from datetime import timedelta

import numpy as np
import pandas as pd

import data
import data_utils
import gp_features
from gp_utils import *  # provides is_user_answers_suggested_event and get_suggestable_questions

# Define output directory and make sure it exists before CSVs are written into it
save_dir = "baseline_data"
os.makedirs(save_dir, exist_ok=True)

# data handle
data_handle = data.Data()

# define which features to include --> with TTM
feature_collection = gp_features.GP_Feature_Collection(
    gp_features.GP_Features_affinity(),
    gp_features.GP_Features_TTM(),
    gp_features.GP_Features_Question(),
    gp_features.GP_Features_user())
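# Going by the class names, each GP_Features_* group contributes one block of
# columns per (answerer, candidate question) pair: asker/answerer affinity,
# time-to-answer (TTM), question-level and user-level signals. The exact
# columns are defined in gp_features.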

# parameters for suggested questions
hour_threshold_suggested_answer = 24
only_open_questions_suggestable = False
filter_nan_asker_id = True
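# Going by the names and call sites below: an answer counts as a "suggested
# question" event if it falls within hour_threshold_suggested_answer hours;
# only_open_questions_suggestable restricts the candidate pool to still-open
# questions; filter_nan_asker_id drops questions with a missing asker id.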

start_time = None  # data_utils.make_datetime("01.01.2012 00:01")
end_time = data_utils.make_datetime("01.01.2016 00:01")  # data_utils.make_datetime("01.03.2012 00:01")
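# start_time = None lets the iterator start at the beginning of the data; the
# commented-out datetimes are presumably narrower windows for quick test runs.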

all_feats_collector = list()
n_candidates_collector = list()

save_every = 10000
q_a_pair_counter = 1
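# Feature rows are buffered in all_feats_collector and flushed to a numbered
# CSV every save_every events, so memory stays bounded and partial results
# survive a crash.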

for i, event in enumerate(data_utils.all_answer_events_iterator(timedelta(days=2), start_time=start_time, end_time=end_time)):
    if np.isnan(event.answerer_user_id) or np.isnan(event.asker_user_id):
        continue

    if i % 100 == 0 and len(n_candidates_collector) > 0:
        avg_candidates = np.mean(n_candidates_collector)
        print("Pretraining at {} | on average {} candidates in the last {} suggested_question_events".format(
            event.answer_date, avg_candidates, len(n_candidates_collector)))
        n_candidates_collector = list()

    if is_user_answers_suggested_event(event, hour_threshold_suggested_answer):

        # NOTE: cached_data is not defined in this file; it is assumed to come
        # in via the star import from gp_utils above.
        suggestable_questions = get_suggestable_questions(
            event.answer_date, cached_data, only_open_questions_suggestable,
            hour_threshold_suggested_answer, filter_nan_asker_id)
        if len(suggestable_questions) == 0:
            # warnings.warn("For answer id {} (to question {}) there was not a single suggestable question".format(event.answer_id, event.question_id))
            continue

        n_candidates_collector.append(len(suggestable_questions))

        # only start appending once past a certain time?
        feats = feature_collection.compute_features(event.answerer_user_id, suggestable_questions, event.answer_date)
        label = suggestable_questions.question_id.values == event.question_id
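        # One row per candidate question; the label marks the question the user
        # actually answered, so each decision group has exactly one positive
        # (enforced by the assert below). This is the usual setup for ranking
        # metrics such as MRR.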

        # add some more information
        feats["question_id"] = suggestable_questions.question_id.values.tolist()  # remember question ids
        feats["decision_time"] = q_a_pair_counter  # for MRR need to remember groups
        feats["label"] = label.astype(int)
        feats["answer_date"] = pd.Series([event.answer_date for _ in range(len(feats))])

        all_feats_collector.append(feats)

        assert np.sum(label.astype(int)) == 1

    feature_collection.update_pos_event(event)  # update features in any case

    # save in between and clear variables in order to back up
    if (q_a_pair_counter + 1) % save_every == 0 and len(all_feats_collector) > 0:
        save_name = "feature_data_" + str((q_a_pair_counter + 1) // save_every) + ".csv"
        features_table = pd.concat(all_feats_collector, axis=0)
        features_table.to_csv(os.path.join(save_dir, save_name), index=False)
        print("Successfully saved data in between:", save_name)
        del features_table
        all_feats_collector = list()

    q_a_pair_counter += 1

if (q_a_pair_counter + 1) % save_every != 0 and len(all_feats_collector) > 0:  # last batch has not been saved yet
    save_name = "feature_data_" + str((q_a_pair_counter + 1) // save_every + 1) + ".csv"
    features_table = pd.concat(all_feats_collector, axis=0)
    features_table.to_csv(os.path.join(save_dir, save_name), index=False)
    print("Successfully saved last data")
print("Finished")
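
# A minimal sketch of downstream use, not part of this script: load the saved
# batches and fit a scikit-learn classifier on them. The meta columns dropped
# below are the ones written above; everything else is an assumption.
#
#   import glob
#   from sklearn.ensemble import RandomForestClassifier
#   files = sorted(glob.glob(os.path.join(save_dir, "feature_data_*.csv")))
#   table = pd.concat([pd.read_csv(f) for f in files], axis=0)
#   X = table.drop(columns=["question_id", "decision_time", "label", "answer_date"])
#   y = table["label"]
#   clf = RandomForestClassifier(n_estimators=100).fit(X, y)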