Skip to content

Commit 25d22da

Browse files
committed
Update sample function in browser and node env
1 parent eac956c commit 25d22da

4 files changed

Lines changed: 56 additions & 62 deletions

File tree

danfojs-browser/src/core/series.js

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
*/
1515

1616

17-
import { tensor, round } from "@tensorflow/tfjs";
17+
import * as tf from "@tensorflow/tfjs";
1818
import { variance, std } from 'mathjs';
1919
import { Utils } from "./utils";
2020
import { Str } from "./strings";
@@ -55,7 +55,7 @@ export class Series extends NDframe {
5555
* @returns {1D Tensor}
5656
*/
5757
get tensor() {
58-
return tensor(this.values).asType(this.dtypes[0]);
58+
return tf.tensor(this.values).asType(this.dtypes[0]);
5959
}
6060

6161

@@ -95,30 +95,22 @@ export class Series extends NDframe {
9595
}
9696

9797
/**
98-
* Returns n number of random rows in a Series
99-
* @param {rows} number of rows to return
100-
* @returns {Series}
98+
* Gets [num] number of random rows in a Series
99+
* @param {number} num --> The number of rows to return
100+
* @param {number} seed --> (Optional) An integer specifying the random seed that will be used to create the distribution.
101+
* @returns {Promise} resolves to a Series object
101102
*/
102-
sample(num = 5) {
103-
if (num > this.values.length || num < 1) {
104-
let config = { columns: this.column_names };
105-
return new Series(this.values, config);
106-
} else {
107-
let values = this.values;
108-
let idx = this.index;
109-
let new_values = [];
110-
let new_idx = [];
111-
let rand_nums = utils.__shuffle(num, idx);
112-
113-
rand_nums.forEach((i) => {
114-
new_values.push(values[i]);
115-
new_idx.push(idx[i]);
116-
});
117-
let config = { columns: this.column_names, index: new_idx };
118-
let sf = new Series(new_values, config);
119-
return sf;
120-
103+
async sample(num = 5, seed = 1) {
104+
if (num > this.shape[0]) {
105+
throw new Error("Sample size n cannot be bigger than size of dataset");
121106
}
107+
if (num < -1 || num == 0) {
108+
throw new Error("Sample size cannot be less than -1 or 0");
109+
}
110+
num = num === -1 ? this.shape[0] : num;
111+
const shuffled_index = await tf.data.array(this.index).shuffle(num, seed).take(num).toArray();
112+
const sf = this.iloc(shuffled_index);
113+
return sf;
122114
}
123115

124116
/**
@@ -250,7 +242,7 @@ export class Series extends NDframe {
250242
mean() {
251243
utils._throw_str_dtype_error(this, 'mean');
252244
let values = utils._remove_nans(this.values);
253-
let mean = tensor(values).mean().arraySync();
245+
let mean = tf.tensor(values).mean().arraySync();
254246
return mean;
255247
}
256248

@@ -382,7 +374,7 @@ export class Series extends NDframe {
382374
round(dp) {
383375
if (utils.__is_undefined(dp)) {
384376
//use tensorflow round function to round to the nearest whole number
385-
let result = round(this.row_data_tensor).arraySync();
377+
let result = tf.round(this.row_data_tensor).arraySync();
386378
return new Series(result, { columns: this.column_names, index: this.index });
387379

388380
} else {

danfojs-browser/tests/core/series.js

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/* eslint-disable no-undef */
2-
const tf = require("@tensorflow/tfjs-core");
2+
const tf = require("@tensorflow/tfjs");
33

44
describe("Series", function () {
55
describe("tensor", function () {
@@ -78,20 +78,25 @@ describe("Series", function () {
7878
});
7979

8080
describe("sample", function () {
81-
it("Samples n number of random elements from a DataFrame", function () {
81+
it("Samples n number of random elements from a DataFrame", async function () {
8282
let data = [ 1, 2, 3, 4, 5, 620, 30, 40, 39, 89, 78 ];
8383
let sf = new dfd.Series(data);
84-
assert.deepEqual(sf.sample(7).values.length, 7);
84+
assert.deepEqual((await sf.sample(7)).values.length, 7);
8585
});
86-
it("Return all values if n of sample is greater than lenght of Dataframe", function () {
86+
it("Return all values if n of sample -1", async function () {
8787
let data = [ 1, 2, 3, 4, 5, 620, 30, 40, 39, 89, 78 ];
8888
let sf = new dfd.Series(data);
89-
assert.deepEqual(sf.sample(21).values.length, data.length);
89+
assert.deepEqual((await sf.sample(-1)).values.length, data.length);
9090
});
91-
it("Return all values if n of sample is less than 1", function () {
91+
it("Throw error if n is greater than lenght of Series", async function () {
9292
let data = [ 1, 2, 3, 4, 5, 620, 30, 40, 39, 89, 78 ];
9393
let sf = new dfd.Series(data);
94-
assert.deepEqual(sf.sample(-2).values.length, data.length);
94+
try {
95+
await sf.sample(100);
96+
} catch (e) {
97+
expect(e).to.be.instanceOf(Error);
98+
expect(e.message).to.eql('Sample size n cannot be bigger than size of dataset');
99+
}
95100
});
96101
});
97102

danfojs-node/src/core/series.js

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -94,30 +94,22 @@ export class Series extends NDframe {
9494
}
9595

9696
/**
97-
* Returns n number of random rows in a Series
98-
* @param {rows} number of rows to return
99-
* @returns {Series}
97+
* Gets [num] number of random rows in a Series
98+
* @param {number} num --> The number of rows to return
99+
* @param {number} seed --> (Optional) An integer specifying the random seed that will be used to create the distribution.
100+
* @returns {Promise} resolves to a Series object
100101
*/
101-
sample(num = 5) {
102-
if (num > this.values.length || num < 1) {
103-
let config = { columns: this.column_names };
104-
return new Series(this.values, config);
105-
} else {
106-
let values = this.values;
107-
let idx = this.index;
108-
let new_values = [];
109-
let new_idx = [];
110-
let rand_nums = utils.__shuffle(num, idx);
111-
112-
rand_nums.forEach((i) => {
113-
new_values.push(values[i]);
114-
new_idx.push(idx[i]);
115-
});
116-
let config = { columns: this.column_names, index: new_idx };
117-
let sf = new Series(new_values, config);
118-
return sf;
119-
102+
async sample(num = 5, seed = 1) {
103+
if (num > this.shape[0]) {
104+
throw new Error("Sample size n cannot be bigger than size of dataset");
120105
}
106+
if (num < -1 || num == 0) {
107+
throw new Error("Sample size cannot be less than -1 or 0");
108+
}
109+
num = num === -1 ? this.shape[0] : num;
110+
const shuffled_index = await tf.data.array(this.index).shuffle(num, seed).take(num).toArray();
111+
const sf = this.iloc(shuffled_index);
112+
return sf;
121113
}
122114

123115
/**

danfojs-node/tests/core/series.js

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { assert } from "chai";
1+
import { assert, expect } from "chai";
22
import { Series } from "../../src/core/series";
33
import * as tf from '@tensorflow/tfjs-node';
44

@@ -79,20 +79,25 @@ describe("Series", function () {
7979
});
8080

8181
describe("sample", function () {
82-
it("Samples n number of random elements from a DataFrame", function () {
82+
it("Samples n number of random elements from a DataFrame", async function () {
8383
let data = [ 1, 2, 3, 4, 5, 620, 30, 40, 39, 89, 78 ];
8484
let sf = new Series(data);
85-
assert.deepEqual(sf.sample(7).values.length, 7);
85+
assert.deepEqual((await sf.sample(7)).values.length, 7);
8686
});
87-
it("Return all values if n of sample is greater than lenght of Dataframe", function () {
87+
it("Return all values if n of sample -1", async function () {
8888
let data = [ 1, 2, 3, 4, 5, 620, 30, 40, 39, 89, 78 ];
8989
let sf = new Series(data);
90-
assert.deepEqual(sf.sample(21).values.length, data.length);
90+
assert.deepEqual((await sf.sample(-1)).values.length, data.length);
9191
});
92-
it("Return all values if n of sample is less than 1", function () {
92+
it("Throw error if n is greater than lenght of Series", async function () {
9393
let data = [ 1, 2, 3, 4, 5, 620, 30, 40, 39, 89, 78 ];
9494
let sf = new Series(data);
95-
assert.deepEqual(sf.sample(-2).values.length, data.length);
95+
try {
96+
await sf.sample(100);
97+
} catch (e) {
98+
expect(e).to.be.instanceOf(Error);
99+
expect(e.message).to.eql('Sample size n cannot be bigger than size of dataset');
100+
}
96101
});
97102
});
98103

0 commit comments

Comments
 (0)