import pdb
from abc import ABC, abstractmethod
import math
import torch
class Kernel(ABC):
    """Abstract interface for covariance kernels.

    Concrete kernels implement :meth:`buildKernelMatrix` and
    :meth:`buildKernelMatrixDiag`. The accessors below read and write the
    ``_params`` attribute; this base class does not create it, so a
    subclass (or a call to :meth:`setParams`) must set it before
    :meth:`getParams` is used.
    """

    @abstractmethod
    def buildKernelMatrix(self, X1, X2=None):
        """Build the kernel matrix between ``X1`` and ``X2``.

        The base implementation does nothing and returns ``None``.
        """

    @abstractmethod
    def buildKernelMatrixDiag(self, X):
        """Build the diagonal of the kernel matrix for ``X``.

        The base implementation does nothing and returns ``None``.
        """

    def getParams(self):
        """Return the object currently stored in ``self._params``."""
        return self._params

    def setParams(self, params):
        """Store ``params`` in ``self._params`` (no copy is made)."""
        self._params = params
class ExponentialQuadraticKernel(Kernel):
    """Exponentiated-quadratic kernel.

    Evaluates ``scale**2 * exp(-0.5 * d**2 / lengthScale**2)`` where ``d``
    is the pairwise difference between inputs.

    ``scale`` and ``lengthScale`` may each be fixed at construction time.
    Whichever is left as ``None`` is a free parameter read from the
    parameter vector (``self._params``), in the order
    ``[scale, lengthScale]`` with fixed entries skipped.

    Args:
        scale: fixed scale value, or ``None`` to make it a free parameter.
        lengthScale: fixed length scale, or ``None`` to make it a free
            parameter.
        dtype: dtype of the zero-initialised free-parameter vector.
    """

    def __init__(self, scale=None, lengthScale=None, dtype=torch.double):
        self._scaleFixed = scale is not None
        self._lengthScaleFixed = lengthScale is not None
        if self._scaleFixed:
            self._scale = scale
        if self._lengthScaleFixed:
            self._lengthScale = lengthScale
        # One slot per free (non-fixed) parameter.
        nFree = int(not self._scaleFixed) + int(not self._lengthScaleFixed)
        self._params = torch.zeros(nFree, dtype=dtype)

    def buildKernelMatrix(self, X1, X2=None):
        """Return the kernel matrix between ``X1`` and ``X2``.

        Args:
            X1: input tensor.
            X2: second input tensor; defaults to ``X1``.

        Returns:
            Tensor of kernel values. For non-3-D ``X1`` both inputs are
            flattened and the result has shape ``(X1.numel(), X2.numel())``.

        Raises:
            ValueError: if both scale and lengthScale are fixed.
        """
        scale, lengthScale = self._getAllParams(params=self._params)
        if X2 is None:
            X2 = X1
        if X1.ndim == 3:
            # NOTE(review): presumably X1 is (batch, n, 1) and X2 is
            # (batch, m, 1), so broadcasting yields (batch, n, m) -- confirm
            # against callers.
            sqDistance = (X1 - X2.transpose(1, 2)) ** 2
        else:
            # Column vector minus row vector -> all pairwise differences.
            sqDistance = (X1.reshape(-1, 1) - X2.reshape(1, -1)) ** 2
        return scale ** 2 * torch.exp(-.5 * sqDistance / lengthScale ** 2)

    def buildKernelMatrixDiag(self, X):
        """Return ``scale**2`` broadcast over a tensor shaped like ``X``.

        Raises:
            ValueError: if both scale and lengthScale are fixed.
        """
        scale, _ = self._getAllParams(params=self._params)
        return scale ** 2 * torch.ones(X.shape, dtype=X.dtype)

    def _getAllParams(self, params):
        """Resolve ``(scale, lengthScale)`` from fixed values and ``params``.

        Args:
            params: indexable holding the free parameters in the order
                ``[scale, lengthScale]`` with fixed entries omitted.

        Returns:
            Tuple ``(scale, lengthScale)``.

        Raises:
            ValueError: if both scale and lengthScale are fixed.
        """
        if self._scaleFixed and self._lengthScaleFixed:
            raise ValueError("Scale and lengthScale cannot be both fixed")

        # Walk the free-parameter vector, consuming one slot per free
        # parameter, so the index always matches the packing order.
        nextFree = 0
        if self._scaleFixed:
            scale = self._scale
        else:
            scale = params[nextFree]
            nextFree += 1
        if self._lengthScaleFixed:
            lengthScale = self._lengthScale
        else:
            lengthScale = params[nextFree]
        return scale, lengthScale
class PeriodicKernel(Kernel):
    """Periodic kernel.

    Evaluates
    ``scale**2 * exp(-2 * sin(pi * d / period)**2 / lengthScale**2)``
    where ``d`` is the pairwise (signed) difference between inputs.

    ``scale``, ``lengthScale`` and ``period`` may each be fixed at
    construction time. Whichever are left as ``None`` are free parameters
    read from the parameter vector (``self._params``), in the order
    ``[scale, lengthScale, period]`` with fixed entries skipped.

    Args:
        scale: fixed scale value, or ``None`` to make it a free parameter.
        lengthScale: fixed length scale, or ``None`` to make it free.
        period: fixed period, or ``None`` to make it free.
        dtype: dtype of the zero-initialised free-parameter vector.
    """

    def __init__(self, scale=None, lengthScale=None, period=None, dtype=torch.double):
        self._scaleFixed = scale is not None
        self._lengthScaleFixed = lengthScale is not None
        self._periodFixed = period is not None
        if self._scaleFixed:
            self._scale = scale
        if self._lengthScaleFixed:
            self._lengthScale = lengthScale
        if self._periodFixed:
            self._period = period
        # One slot per free (non-fixed) parameter.
        nFree = sum(
            int(not fixed)
            for fixed in (self._scaleFixed, self._lengthScaleFixed, self._periodFixed)
        )
        self._params = torch.zeros(nFree, dtype=dtype)

    def buildKernelMatrix(self, X1, X2=None):
        """Return the kernel matrix between ``X1`` and ``X2``.

        Args:
            X1: input tensor.
            X2: second input tensor; defaults to ``X1``.

        Returns:
            Tensor of kernel values. For non-3-D ``X1`` both inputs are
            flattened and the result has shape ``(X1.numel(), X2.numel())``.

        Raises:
            ValueError: if scale, lengthScale and period are all fixed.
        """
        scale, lengthScale, period = self._getAllParams(params=self._params)
        if X2 is None:
            X2 = X1
        if X1.ndim == 3:
            # NOTE(review): presumably X1 is (batch, n, 1) and X2 is
            # (batch, m, 1), so broadcasting yields (batch, n, m) -- confirm
            # against callers.
            sDistance = X1 - X2.transpose(1, 2)
        else:
            # Column vector minus row vector -> all pairwise differences.
            sDistance = X1.reshape(-1, 1) - X2.reshape(1, -1)
        phase = math.pi * sDistance / period
        return scale ** 2 * torch.exp(-2 * torch.sin(phase) ** 2 / lengthScale ** 2)

    def buildKernelMatrixDiag(self, X):
        """Return ``scale**2`` broadcast over a tensor shaped like ``X``.

        Raises:
            ValueError: if scale, lengthScale and period are all fixed.
        """
        scale = self._getAllParams(params=self._params)[0]
        return scale ** 2 * torch.ones(X.shape, dtype=X.dtype)

    def _getAllParams(self, params):
        """Resolve ``(scale, lengthScale, period)`` from fixed values and ``params``.

        Args:
            params: indexable holding the free parameters in the order
                ``[scale, lengthScale, period]`` with fixed entries omitted.

        Returns:
            Tuple ``(scale, lengthScale, period)``.

        Raises:
            ValueError: if all three parameters are fixed.
        """
        if self._scaleFixed and self._lengthScaleFixed and self._periodFixed:
            raise ValueError("Scale, lengthScale and period cannot all be fixed")

        # Walk the free-parameter vector, consuming one slot per free
        # parameter. This covers every fixed/free combination uniformly and
        # keeps the index in sync with the packing order.
        nextFree = 0
        if self._scaleFixed:
            scale = self._scale
        else:
            scale = params[nextFree]
            nextFree += 1
        if self._lengthScaleFixed:
            lengthScale = self._lengthScale
        else:
            lengthScale = params[nextFree]
            nextFree += 1
        if self._periodFixed:
            period = self._period
        else:
            period = params[nextFree]
        return scale, lengthScale, period
'''
class AddDiagKernel(Kernel):
def __init__(self, kernel, epsilon=1e-5):
self.__kernel = kernel
self.__epsilon = epsilon
def buildKernelMatrix(self, X1, X2=None):
covMatrix = self.__kernel.buildKernelMatrix(X1=X1, X2=X2)
covMatrixPlusDiag = (covMatrix +
self.__epsilon*torch.eye(n=covMatrix.shape[0],
dtype=X1.dtype))
return covMatrixPlusDiag
def buildKernelMatrixDiag(self, X):
return self.__kernel.buildKernelMatrixDiag(X=X)
def setParams(self, params):
self.__kernel.setParams(params)
def getParams(self):
params = self.__kernel.getParams()
return params
'''