|
14 | 14 | */ |
15 | 15 |
|
16 | 16 |
|
17 | | -import { tensor, round } from "@tensorflow/tfjs"; |
| 17 | +import * as tf from "@tensorflow/tfjs"; |
18 | 18 | import { variance, std } from 'mathjs'; |
19 | 19 | import { Utils } from "./utils"; |
20 | 20 | import { Str } from "./strings"; |
@@ -55,7 +55,7 @@ export class Series extends NDframe { |
55 | 55 | * @returns {1D Tensor} |
56 | 56 | */ |
57 | 57 | get tensor() { |
58 | | - return tensor(this.values).asType(this.dtypes[0]); |
| 58 | + return tf.tensor(this.values).asType(this.dtypes[0]); |
59 | 59 | } |
60 | 60 |
|
61 | 61 |
|
@@ -95,30 +95,22 @@ export class Series extends NDframe { |
95 | 95 | } |
96 | 96 |
|
97 | 97 | /** |
98 | | - * Returns n number of random rows in a Series |
99 | | - * @param {rows} number of rows to return |
100 | | - * @returns {Series} |
| 98 | + * Gets [num] number of random rows in a dataframe |
| 99 | + * @param {num} rows --> The number of rows to return |
| 100 | + * @param {seed} seed --> (Optional) An integer specifying the random seed that will be used to create the distribution. |
| 101 | + * @returns {Promise} resolves to a Series object |
101 | 102 | */ |
102 | | - sample(num = 5) { |
103 | | - if (num > this.values.length || num < 1) { |
104 | | - let config = { columns: this.column_names }; |
105 | | - return new Series(this.values, config); |
106 | | - } else { |
107 | | - let values = this.values; |
108 | | - let idx = this.index; |
109 | | - let new_values = []; |
110 | | - let new_idx = []; |
111 | | - let rand_nums = utils.__shuffle(num, idx); |
112 | | - |
113 | | - rand_nums.forEach((i) => { |
114 | | - new_values.push(values[i]); |
115 | | - new_idx.push(idx[i]); |
116 | | - }); |
117 | | - let config = { columns: this.column_names, index: new_idx }; |
118 | | - let sf = new Series(new_values, config); |
119 | | - return sf; |
120 | | - |
| 103 | + async sample(num = 5, seed = 1) { |
| 104 | + if (num > this.shape[0]) { |
| 105 | + throw new Error("Sample size n cannot be bigger than size of dataset"); |
121 | 106 | } |
| 107 | + if (num < -1 || num == 0) { |
| 108 | + throw new Error("Sample size cannot be less than -1 or 0"); |
| 109 | + } |
| 110 | + num = num === -1 ? this.shape[0] : num; |
| 111 | + const shuffled_index = await tf.data.array(this.index).shuffle(num, seed).take(num).toArray(); |
| 112 | + const sf = this.iloc(shuffled_index); |
| 113 | + return sf; |
122 | 114 | } |
123 | 115 |
|
124 | 116 | /** |
@@ -250,7 +242,7 @@ export class Series extends NDframe { |
250 | 242 | mean() { |
251 | 243 | utils._throw_str_dtype_error(this, 'mean'); |
252 | 244 | let values = utils._remove_nans(this.values); |
253 | | - let mean = tensor(values).mean().arraySync(); |
| 245 | + let mean = tf.tensor(values).mean().arraySync(); |
254 | 246 | return mean; |
255 | 247 | } |
256 | 248 |
|
@@ -382,7 +374,7 @@ export class Series extends NDframe { |
382 | 374 | round(dp) { |
383 | 375 | if (utils.__is_undefined(dp)) { |
384 | 376 | //use tensorflow round function to roound to the nearest whole number |
385 | | - let result = round(this.row_data_tensor).arraySync(); |
| 377 | + let result = tf.round(this.row_data_tensor).arraySync(); |
386 | 378 | return new Series(result, { columns: this.column_names, index: this.index }); |
387 | 379 |
|
388 | 380 | } else { |
|
0 commit comments