Skip to content

Commit 36a87ff

Browse files
committed
FIX: Changes as per new index_select; imports.
1 parent 17a3579 commit 36a87ff

12 files changed

Lines changed: 20 additions & 205 deletions

deepxml/libs/dataset.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,7 @@
1-
import torch
2-
import _pickle as pickle
31
import os
4-
import sys
5-
from scipy.sparse import lil_matrix
62
import numpy as np
7-
from sklearn.preprocessing import normalize
83
from .dataset_base import DatasetBase, DatasetTensor
9-
import xclib.data.data_utils as data_utils
104
from .dist_utils import Partitioner
11-
import operator
125
from xclib.utils.sparse import _map
136
from .shortlist_handler import construct_handler
147

deepxml/libs/dataset_base.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,7 @@
11
import torch
2-
import _pickle as pickle
2+
import pickle
33
import os
4-
import sys
5-
from scipy.sparse import lil_matrix
64
import numpy as np
7-
from sklearn.preprocessing import normalize
8-
import xclib.data.data_utils as data_utils
9-
import operator
10-
from .lookup import Table, PartitionedTable
115
from .features import construct as construct_f
126
from .labels import construct as construct_l
137

@@ -44,7 +38,7 @@ def construct(self, data_dir, fname, data, indices, normalize, _type):
4438
data = construct_f(data_dir, fname, data, normalize, _type)
4539
if indices is not None:
4640
indices = np.loadtxt(indices, dtype=np.int64)
47-
data.index_select(indices)
41+
data._index_select(indices)
4842
return data
4943

5044
def __len__(self):
@@ -126,12 +120,12 @@ def __init__(self, data_dir, fname_features, fname_labels,
126120
def _remove_samples_wo_features_and_labels(self):
127121
"""Remove instances if they don't have any feature or label
128122
"""
129-
indices = self.features.get_valid(axis=1)
123+
indices = self.features.get_valid_indices(axis=1)
130124
if self.labels is not None:
131-
indices_labels = self.labels.get_valid(axis=1)
125+
indices_labels = self.labels.get_valid_indices(axis=1)
132126
indices = np.intersect1d(indices, indices_labels)
133-
self.labels.index_select(indices, axis=0)
134-
self.features.index_select(indices, axis=0)
127+
self.labels._index_select(indices, axis=0)
128+
self.features._index_select(indices, axis=0)
135129

136130
def index_select(self, feature_indices, label_indices):
137131
"""Transform feature and label matrix to specified
@@ -145,11 +139,11 @@ def _get_split_id(fname):
145139
if label_indices is not None:
146140
self._split = _get_split_id(label_indices)
147141
label_indices = np.loadtxt(label_indices, dtype=np.int32)
148-
self.labels.index_select(label_indices, axis=1)
142+
self.labels._index_select(label_indices, axis=1)
149143
if feature_indices is not None:
150144
self._split = _get_split_id(feature_indices)
151145
feature_indices = np.loadtxt(feature_indices, dtype=np.int32)
152-
self.features.index_select(feature_indices, axis=1)
146+
self.features._index_select(feature_indices, axis=1)
153147

154148
def load_features(self, data_dir, fname, X,
155149
normalize_features, feature_type):
@@ -216,7 +210,7 @@ def _process_labels_predict(self, data_obj):
216210
example
217211
"""
218212
valid_labels = data_obj['valid_labels']
219-
self.labels.index_select(valid_labels)
213+
self.labels._index_select(valid_labels, axis=1)
220214

221215
def _process_labels(self, model_dir):
222216
"""Process labels to handle labels without any training instance;

deepxml/libs/features.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
from xclib.data.features import FeaturesBase, DenseFeatures, SparseFeatures
22
import numpy as np
3-
import _pickle as pickle
4-
from xclib.data import data_utils
5-
import os
63
from operator import itemgetter
74

85

deepxml/libs/model.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,9 @@
11
import logging
2-
import math
32
import os
43
import time
5-
from scipy.sparse import lil_matrix
6-
import _pickle as pickle
74
from .model_base import ModelBase
85
import torch.utils.data
9-
from torch.utils.data import DataLoader
106
from .features import DenseFeatures
11-
import numpy as np
12-
import sys
13-
import libs.utils as utils
147
from xclib.utils.matrix import SMatrix
158
from xclib.utils.sparse import sigmoid
169
from tqdm import tqdm

deepxml/libs/model_base.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import logging
2-
import math
32
import os
43
import time
5-
from scipy.sparse import lil_matrix, issparse
6-
import _pickle as pickle
4+
from scipy.sparse import issparse
75
import sys
86
import torch.utils.data
97
from torch.utils.data import DataLoader

deepxml/libs/predictions.py

Lines changed: 0 additions & 144 deletions
This file was deleted.

deepxml/libs/sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
import _pickle as pickle
2+
import pickle
33
from functools import partial
44

55

deepxml/libs/shortlist.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
1-
import numpy as np
2-
import _pickle as pickle
3-
from scipy.sparse import csr_matrix, diags
1+
import pickle
42
from xclib.utils.sparse import topk, csr_from_arrays
5-
import os
6-
import numba
73
from xclib.utils.shortlist import Shortlist
84
from xclib.utils.shortlist import ShortlistCentroids
95
from xclib.utils.shortlist import ShortlistInstances

deepxml/libs/shortlist_handler.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
import numpy as np
2-
import _pickle as pickle
32
from .dist_utils import Partitioner
4-
import operator
53
import os
6-
from .lookup import Table, PartitionedTable
74
from .sampling import NegativeSampler
85
from scipy.sparse import load_npz
96
from xclib.utils import sparse as sp

deepxml/libs/tracking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Tracking object; Maintain history of loss; accuracy etc.
33
"""
44

5-
import _pickle as pickle
5+
import pickle
66

77

88
class Tracking(object):

0 commit comments

Comments
 (0)