-
Notifications
You must be signed in to change notification settings - Fork 67
Expand file tree
/
Copy pathtrain.py
More file actions
79 lines (61 loc) · 2.5 KB
/
train.py
File metadata and controls
79 lines (61 loc) · 2.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# simple Lambda function training a scikit-learn model on the digits classification dataset
# see https://scikit-learn.org/stable/auto_examples/classification/plot_digits_classification.html
import os
import boto3
import numpy
from sklearn import datasets, svm, metrics
from sklearn.utils import Bunch
from sklearn.model_selection import train_test_split
from joblib import dump, load
import io
def handler(event, context):
    """Lambda entry point: train a digit classifier and publish artifacts to S3.

    Loads the digits dataset via ``load_digits``, fits a support vector
    classifier on the first half of the samples, then uploads the serialized
    model (``model.joblib``) and the held-out feature matrix
    (``test-set.npy``) to the ``reproducible-ml`` bucket.

    Args:
        event: Lambda invocation payload; not read by this function.
        context: Lambda runtime context; not read by this function.

    Returns:
        None.
    """
    bucket_name = "reproducible-ml"
    test_set_path = '/tmp/test-set.npy'

    dataset = load_digits()

    # Flatten every image into a single feature row (one row per sample).
    sample_count = len(dataset.images)
    features = dataset.images.reshape((sample_count, -1))

    # Support vector classifier, same hyper-parameter as the scikit-learn example.
    model = svm.SVC(gamma=0.001)

    # 50/50 split; shuffle=False means the first half trains and the second
    # half is held out, so the split is identical on every invocation.
    train_features, held_out_features, train_labels, _held_out_labels = train_test_split(
        features, dataset.target, test_size=0.5, shuffle=False
    )

    # Fit on the training half only.
    model.fit(train_features, train_labels)

    # Serialize the fitted model into memory and upload it.
    s3_client = boto3.client("s3")
    model_bytes = io.BytesIO()
    dump(model, model_bytes)
    s3_client.put_object(
        Body=model_bytes.getvalue(),
        Bucket=bucket_name,
        Key="model.joblib",
    )

    # Persist the held-out features locally, then stream that file to S3.
    numpy.save(test_set_path, held_out_features)
    with open(test_set_path, 'rb') as test_set_file:
        s3_client.put_object(
            Body=test_set_file,
            Bucket=bucket_name,
            Key="test-set.npy",
        )
def load_digits(*, n_class=10, return_X_y=False, as_frame=False):
    """Load the digits dataset from S3, packaged like scikit-learn's ``load_digits``.

    Downloads ``digits.csv.gz`` from the ``reproducible-ml`` bucket to
    ``/tmp/digits.csv.gz`` and parses it as comma-separated numbers. The last
    column is used as the class label; all preceding columns are the pixel
    values, which are also exposed reshaped to ``(-1, 8, 8)`` images.

    Args:
        n_class: When below 10, keep only samples whose label is strictly
            less than this value.
        return_X_y: When true, return ``(data, target)`` instead of a
            ``Bunch``.
        as_frame: When true, pass data and target through scikit-learn's
            private ``_convert_data_dataframe`` helper and use its results.

    Returns:
        ``(data, target)`` if ``return_X_y`` is true, otherwise a ``Bunch``
        with ``data``, ``target``, ``frame``, ``feature_names``,
        ``target_names`` and ``images`` entries.
    """
    # download files from S3
    s3_client = boto3.client("s3")
    s3_client.download_file(Bucket="reproducible-ml", Key="digits.csv.gz", Filename="/tmp/digits.csv.gz")

    # ndmin=2 keeps the array two-dimensional even for a single-row file, so
    # the column slicing below stays valid.
    data = numpy.loadtxt('/tmp/digits.csv.gz', delimiter=',', ndmin=2)

    # ``numpy.int`` was only an alias of the builtin ``int`` and has been
    # removed from NumPy; the builtin gives the same dtype.
    target = data[:, -1].astype(int, copy=False)
    flat_data = data[:, :-1]
    images = flat_data.reshape(-1, 8, 8)

    if n_class < 10:
        keep = target < n_class
        flat_data, target = flat_data[keep], target[keep]
        images = images[keep]

    feature_names = [f'pixel_{row_idx}_{col_idx}'
                     for row_idx in range(8)
                     for col_idx in range(8)]

    frame = None
    target_columns = ['target', ]
    if as_frame:
        # NOTE(review): this is a private scikit-learn helper. It appears to be
        # defined in sklearn.datasets._base and not re-exported from
        # sklearn.datasets, so fall back to the private module when the
        # attribute is missing -- confirm against the pinned sklearn version.
        convert = getattr(datasets, "_convert_data_dataframe", None)
        if convert is None:
            from sklearn.datasets._base import _convert_data_dataframe as convert
        frame, flat_data, target = convert(
            "load_digits", flat_data, target, feature_names, target_columns)

    if return_X_y:
        return flat_data, target

    return Bunch(data=flat_data,
                 target=target,
                 frame=frame,
                 feature_names=feature_names,
                 target_names=numpy.arange(10),
                 images=images)