Skip to content

Commit

Permalink
Add ShrunkCovariance (#309)
Browse files Browse the repository at this point in the history
  • Loading branch information
norm4nn authored Nov 19, 2024
1 parent 58ddd65 commit f84177f
Show file tree
Hide file tree
Showing 5 changed files with 315 additions and 41 deletions.
44 changes: 5 additions & 39 deletions lib/scholar/covariance/ledoit_wolf.ex
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ defmodule Scholar.Covariance.LedoitWolf do
defstruct [:covariance, :shrinkage, :location]

opts_schema = [
assume_centered: [
assume_centered?: [
default: false,
type: :boolean,
doc: """
Expand Down Expand Up @@ -93,7 +93,7 @@ defmodule Scholar.Covariance.LedoitWolf do
iex> key = Nx.Random.key(0)
iex> {x, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0.0, 0.0, 0.0]), Nx.tensor([[3.0, 2.0, 1.0], [1.0, 2.0, 3.0], [1.3, 1.0, 2.2]]), shape: {10}, type: :f32)
iex> cov = Scholar.Covariance.LedoitWolf.fit(x, assume_centered: true)
iex> cov = Scholar.Covariance.LedoitWolf.fit(x, assume_centered?: true)
iex> cov.covariance
#Nx.Tensor<
f32[3][3]
Expand All @@ -110,7 +110,7 @@ defmodule Scholar.Covariance.LedoitWolf do
end

defnp fit_n(x, opts) do
{x, location} = center(x, opts)
{x, location} = Scholar.Covariance.Utils.center(x, opts[:assume_centered?])

{covariance, shrinkage} =
ledoit_wolf(x)
Expand All @@ -122,23 +122,6 @@ defmodule Scholar.Covariance.LedoitWolf do
}
end

defnp center(x, opts) do
x =
case Nx.shape(x) do
{_} -> Nx.new_axis(x, 1)
_ -> x
end

location =
if opts[:assume_centered] do
0
else
Nx.mean(x, axes: [0])
end

{x - location, location}
end

defnp ledoit_wolf(x) do
case Nx.shape(x) do
{_n, 1} ->
Expand All @@ -149,23 +132,6 @@ defmodule Scholar.Covariance.LedoitWolf do
end
end

defnp empirical_covariance(x) do
n = Nx.axis_size(x, 0)

covariance = Nx.dot(x, [0], x, [0]) / n

case Nx.shape(covariance) do
{} -> Nx.reshape(covariance, {1, 1})
_ -> covariance
end
end

defnp trace(x) do
x
|> Nx.take_diagonal()
|> Nx.sum()
end

defnp ledoit_wolf_shrinkage(x) do
case Nx.shape(x) do
{_, 1} ->
Expand All @@ -182,9 +148,9 @@ defmodule Scholar.Covariance.LedoitWolf do

defnp ledoit_wolf_shrinkage_complex(x) do
{num_samples, num_features} = Nx.shape(x)
emp_cov = empirical_covariance(x)
emp_cov = Scholar.Covariance.Utils.empirical_covariance(x)

emp_cov_trace = trace(emp_cov)
emp_cov_trace = Scholar.Covariance.Utils.trace(emp_cov)
mu = Nx.sum(emp_cov_trace) / num_features

flatten_delta = Nx.flatten(emp_cov)
Expand Down
119 changes: 119 additions & 0 deletions lib/scholar/covariance/shrunk_covariance.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
defmodule Scholar.Covariance.ShrunkCovariance do
@moduledoc """
Covariance estimator with shrinkage.
"""
import Nx.Defn

@derive {Nx.Container, containers: [:covariance, :location]}
defstruct [:covariance, :location]

opts_schema = [
assume_centered?: [
default: false,
type: :boolean,
doc: """
If `true`, data will not be centered before computation.
Useful when working with data whose mean is almost, but not exactly
zero.
If `false`, data will be centered before computation.
"""
],
shrinkage: [
default: 0.1,
type: :float,
doc: "Coefficient in the convex combination used for the computation
of the shrunk estimate. Range is [0, 1]."
]
]

@opts_schema NimbleOptions.new!(opts_schema)
@doc """
Fit the shrunk covariance model to `x`.
## Options
#{NimbleOptions.docs(@opts_schema)}
## Return Values
The function returns a struct with the following parameters:
* `:covariance` - Tensor of shape `{num_features, num_features}`. Estimated covariance matrix.
* `:location` - Tensor of shape `{num_features,}`.
Estimated location, i.e. the estimated mean.
## Examples
iex> key = Nx.Random.key(0)
iex> {x, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0.0, 0.0]), Nx.tensor([[0.8, 0.3], [0.2, 0.4]]), shape: {10}, type: :f32)
iex> model = Scholar.Covariance.ShrunkCovariance.fit(x)
iex> model.covariance
#Nx.Tensor<
f32[2][2]
[
[0.7721845507621765, 0.19141492247581482],
[0.19141492247581482, 0.33952537178993225]
]
>
iex> model.location
#Nx.Tensor<
f32[2]
[0.18202415108680725, -0.09216632694005966]
>
iex> key = Nx.Random.key(0)
iex> {x, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0.0, 0.0]), Nx.tensor([[0.8, 0.3], [0.2, 0.4]]), shape: {10}, type: :f32)
iex> model = Scholar.Covariance.ShrunkCovariance.fit(x, shrinkage: 0.4)
iex> model.covariance
#Nx.Tensor<
f32[2][2]
[
[0.7000747323036194, 0.1276099532842636],
[0.1276099532842636, 0.41163527965545654]
]
>
iex> model.location
#Nx.Tensor<
f32[2]
[0.18202415108680725, -0.09216632694005966]
>
"""

deftransform fit(x, opts \\ []) do
fit_n(x, NimbleOptions.validate!(opts, @opts_schema))
end

defnp fit_n(x, opts) do
shrinkage = opts[:shrinkage]

if shrinkage < 0 or shrinkage > 1 do
raise ArgumentError,
"""
expected :shrinkage option to be in [0, 1] range, \
got shrinkage: #{inspect(Nx.shape(x))}\
"""
end

{x, location} = Scholar.Covariance.Utils.center(x, opts[:assume_centered?])

covariance =
Scholar.Covariance.Utils.empirical_covariance(x)
|> shrunk_covariance(shrinkage)

%__MODULE__{
covariance: covariance,
location: location
}
end

defnp shrunk_covariance(emp_cov, shrinkage) do
num_features = Nx.axis_size(emp_cov, 1)
shrunk_cov = (1.0 - shrinkage) * emp_cov
emp_cov_trace = Scholar.Covariance.Utils.trace(emp_cov)
mu = Nx.sum(emp_cov_trace) / num_features

mask = Nx.iota(Nx.shape(shrunk_cov))
selector = Nx.remainder(mask, num_features + 1) == 0

shrunk_cov + shrinkage * mu * selector
end
end
39 changes: 39 additions & 0 deletions lib/scholar/covariance/utils.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
defmodule Scholar.Covariance.Utils do
@moduledoc false
import Nx.Defn
require Nx

defn center(x, assume_centered? \\ false) do
x =
case Nx.shape(x) do
{_} -> Nx.new_axis(x, 1)
_ -> x
end

location =
if assume_centered? do
0
else
Nx.mean(x, axes: [0])
end

{x - location, location}
end

defn empirical_covariance(x) do
n = Nx.axis_size(x, 0)

covariance = Nx.dot(x, [0], x, [0]) / n

case Nx.shape(covariance) do
{} -> Nx.reshape(covariance, {1, 1})
_ -> covariance
end
end

defn trace(x) do
x
|> Nx.take_diagonal()
|> Nx.sum()
end
end
4 changes: 2 additions & 2 deletions test/scholar/covariance/ledoit_wolf_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ defmodule Scholar.Covariance.LedoitWolfTest do
)
end

test "fit test - :assume_centered is true" do
test "fit test - :assume_centered? is true" do
key = key()

{x, _new_key} =
Expand All @@ -52,7 +52,7 @@ defmodule Scholar.Covariance.LedoitWolfTest do
type: :f32
)

model = LedoitWolf.fit(x, assume_centered: true)
model = LedoitWolf.fit(x, assume_centered?: true)

assert_all_close(
model.covariance,
Expand Down
Loading

0 comments on commit f84177f

Please sign in to comment.