Skip to content

Commit d2e8906

Browse files
authored
Improve Rasch Model implementation with learning rate and better documentation (#7)
1 parent 91eb41c commit d2e8906

2 files changed

Lines changed: 22 additions & 16 deletions

File tree

lib/irt_ruby/rasch_model.rb

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,42 +5,42 @@
55
module IrtRuby
66
# A class representing the Rasch model for Item Response Theory.
77
class RaschModel
8-
def initialize(data, max_iter: 1000, tolerance: 1e-6)
8+
def initialize(data, max_iter: 1000, tolerance: 1e-6, learning_rate: 0.01)
99
@data = data
1010
@abilities = Array.new(data.row_count) { rand }
1111
@difficulties = Array.new(data.column_count) { rand }
1212
@max_iter = max_iter
1313
@tolerance = tolerance
14+
@learning_rate = learning_rate
1415
end
1516

17+
# Sigmoid function to calculate probability
1618
def sigmoid(x)
1719
1.0 / (1.0 + Math.exp(-x))
1820
end
1921

22+
# Calculate the log-likelihood of the data given the current parameters
2023
def likelihood
2124
likelihood = 0
2225
@data.row_vectors.each_with_index do |row, i|
2326
row.to_a.each_with_index do |response, j|
2427
prob = sigmoid(@abilities[i] - @difficulties[j])
25-
if response == 1
26-
likelihood += Math.log(prob)
27-
elsif response.zero?
28-
likelihood += Math.log(1 - prob)
29-
end
28+
likelihood += response == 1 ? Math.log(prob) : Math.log(1 - prob)
3029
end
3130
end
3231
likelihood
3332
end
3433

34+
# Update parameters using gradient ascent
3535
def update_parameters
3636
last_likelihood = likelihood
3737
@max_iter.times do |_iter|
3838
@data.row_vectors.each_with_index do |row, i|
3939
row.to_a.each_with_index do |response, j|
4040
prob = sigmoid(@abilities[i] - @difficulties[j])
4141
error = response - prob
42-
@abilities[i] += 0.01 * error
43-
@difficulties[j] -= 0.01 * error
42+
@abilities[i] += @learning_rate * error
43+
@difficulties[j] -= @learning_rate * error
4444
end
4545
end
4646
current_likelihood = likelihood
@@ -50,6 +50,7 @@ def update_parameters
5050
end
5151
end
5252

53+
# Fit the model to the data
5354
def fit
5455
update_parameters
5556
{ abilities: @abilities, difficulties: @difficulties }

spec/irt_ruby/rasch_model_spec.rb

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,31 @@
44

55
RSpec.describe IrtRuby::RaschModel do
66
let(:data) { Matrix[[1, 0, 1], [0, 1, 0], [1, 1, 1]] }
7-
let(:irt_model) { IrtRuby::RaschModel.new(data, max_iter: 2000) }
7+
let(:model) { IrtRuby::RaschModel.new(data, max_iter: 2000) }
8+
9+
describe "#initialize" do
10+
it "initializes with data" do
11+
expect(model.instance_variable_get(:@data)).to eq(data)
12+
end
13+
end
814

915
describe "#sigmoid" do
10-
it "calculates the sigmoid of a value" do
11-
expect(irt_model.sigmoid(0)).to be_within(0.01).of(0.5)
12-
expect(irt_model.sigmoid(2)).to be_within(0.01).of(0.88)
16+
it "calculates the sigmoid function" do
17+
expect(model.sigmoid(0)).to eq(0.5)
1318
end
1419
end
1520

1621
describe "#likelihood" do
1722
it "calculates the likelihood of the data" do
18-
expect(irt_model.likelihood).to be_a(Float)
23+
expect(model.likelihood).to be_a(Float)
1924
end
2025
end
2126

2227
describe "#fit" do
2328
it "fits the model and returns abilities and difficulties" do
24-
results = irt_model.fit
25-
expect(results[:abilities].size).to eq(3)
26-
expect(results[:difficulties].size).to eq(3)
29+
result = model.fit
30+
expect(result[:abilities].size).to eq(data.row_count)
31+
expect(result[:difficulties].size).to eq(data.column_count)
2732
end
2833
end
2934
end

0 commit comments

Comments
 (0)