import abc
from copy import copy
class Function(metaclass=abc.ABCMeta):
    """Base class for differentiable operations.

    Subclasses implement ``forward`` (compute the output node and stash what
    ``backward`` needs via ``saved_for_backward``) and ``backward`` (map the
    upstream gradient to one gradient per ``forward`` argument). Users call
    the ``apply`` classmethod, which creates a fresh instance per call.
    """

    def __init__(self) -> None:
        # Values stashed by forward() for use in backward().
        self._buffer = []

    # Declared as instance methods: apply() invokes them on an instance and
    # every concrete subclass defines them with an instance receiver.
    @abc.abstractmethod
    def forward(self, *args, **kwargs):
        """Compute and return the output node for the given inputs."""

    @abc.abstractmethod
    def backward(self, *args, **kwargs):
        """Return one gradient per forward argument, given the output gradient."""

    def saved_for_backward(self, *args):
        """Stash ``args`` (as a single tuple) for a later ``backward`` call."""
        self._buffer.append(args)

    def get_saved_buffer(self):
        """Return the most recently stashed tuple.

        The entry is left in place so ``backward`` can be invoked more than
        once on the same graph (e.g. with ``retain_graph=True``, or when a
        node is reached through several paths).

        Raises:
            IndexError: if nothing was saved.
        """
        return self._buffer[-1]

    @classmethod
    def apply(cls, *args, **kwargs):
        """Run ``forward`` on a fresh instance and record the backward context.

        When the result requires grad, ``result.saved_ctx`` is set to
        ``(args, bound_backward)`` so a later backward pass can route
        gradients to ``args``.
        """
        obj = cls()
        result = obj.forward(*args, **kwargs)
        if result.required_grad:
            result.saved_ctx = (args, obj.backward)
        return result
def as_var(data):
    """Wrap ``data`` as a constant ``Varible`` unless it already is one.

    Non-``Varible`` inputs are converted with ``float()`` and marked as not
    requiring grad; existing ``Varible`` objects are returned unchanged.
    """
    if not isinstance(data, Varible):
        data = Varible(float(data), required_grad=False)
    return data
class Varible(object):
    """A scalar node in a reverse-mode autograd graph.

    Attributes:
        data: the wrapped value.
        required_grad: whether gradients are propagated to this node.
        grad: accumulated gradient; ``None`` until a backward pass reaches it.
        saved_ctx: ``(inputs, backward_func)`` set by ``Function.apply`` on
            nodes that require grad, or ``None`` for leaves / released graphs.
    """

    def __init__(self, value, required_grad=False) -> None:
        self.required_grad = required_grad
        self.grad = None
        self.data = value
        self.saved_ctx = None

    def backward(self, grad=None, retain_graph=False):
        """Back-propagate ``grad`` from this node to every input requiring grad.

        Every reachable node is processed exactly once, after all nodes that
        consume it, so a node shared by several paths receives the sum of its
        contributions. The ``grad`` of each input node reached is incremented
        per incoming edge; this (root) node's own ``grad`` is left untouched.

        Args:
            grad: upstream gradient for this node; defaults to 1.0.
            retain_graph: if False, ``saved_ctx`` of every processed node is
                cleared afterwards, releasing the graph.
        """
        if grad is None:
            # Default only when omitted: an explicit 0.0 is a real gradient.
            grad = 1.0
        if not self.required_grad:
            return
        # Gradient flowing into each node during *this* pass, keyed by id().
        pending = {id(self): grad}
        for node in self._topological_order():
            node_grad = pending.pop(id(node), None)
            if node_grad is None or node.saved_ctx is None:
                continue
            inputs, backward_func = node.saved_ctx
            for var, var_grad in zip(inputs, backward_func(node_grad)):
                if var.required_grad:
                    if var.grad is None:
                        var.grad = 0.0
                    var.grad += var_grad
                    pending[id(var)] = pending.get(id(var), 0.0) + var_grad
            if not retain_graph:
                node.saved_ctx = None

    def _topological_order(self):
        """Return grad-requiring nodes reachable from self, consumers first."""
        visited = set()
        post_order = []
        stack = [(self, False)]
        while stack:
            node, expanded = stack.pop()
            if expanded:
                # All inputs of ``node`` are already in post_order.
                post_order.append(node)
                continue
            if id(node) in visited:
                continue
            visited.add(id(node))
            stack.append((node, True))
            if node.saved_ctx is not None:
                for var in node.saved_ctx[0]:
                    if var.required_grad and id(var) not in visited:
                        stack.append((var, False))
        post_order.reverse()
        return post_order

    def copy_(self, other):
        """Overwrite this node's fields with shallow copies of ``other``'s."""
        self.data = copy(other.data)
        self.required_grad = copy(other.required_grad)
        self.grad = copy(other.grad)
        self.saved_ctx = copy(other.saved_ctx)

    # In-place operators return a fresh node, so ``a += b`` rebinds ``a``.
    # Copying the result into ``self`` would make ``self`` one of its own
    # inputs, creating a cycle in the graph.
    def __iadd__(self, other):
        return self.__add__(other)

    def __isub__(self, other):
        return self.__sub__(other)

    def __imul__(self, other):
        return self.__mul__(other)

    def __itruediv__(self, other):
        return self.__truediv__(other)

    # Kept for backward compatibility; Python 3 uses __itruediv__ for ``/=``.
    def __idiv__(self, other):
        return self.__truediv__(other)

    # Addition and multiplication commute, so operand order does not matter.
    def __radd__(self, other):
        return self.__add__(other)

    def __rmul__(self, other):
        return self.__mul__(other)

    # Reflected forms compute ``other - self`` / ``other / self``.
    def __rsub__(self, other):
        return SubFunction.apply(as_var(other), self)

    def __rtruediv__(self, other):
        return DivFunction.apply(as_var(other), self)

    def __add__(self, other):
        return AddFunction.apply(self, as_var(other))

    def __sub__(self, other):
        return SubFunction.apply(self, as_var(other))

    def __mul__(self, other):
        return MulFunction.apply(self, as_var(other))

    def __truediv__(self, other):
        return DivFunction.apply(self, as_var(other))

    def __str__(self) -> str:
        return f"data={self.data}, required_grad={self.required_grad}, grad={self.grad}"

    def __repr__(self) -> str:
        return self.__str__()
class AddFunction(Function):
    """Differentiable ``x + y``."""

    def forward(ctx, x, y):
        """Return a new node holding ``x.data + y.data``."""
        needs_grad = x.required_grad or y.required_grad
        return Varible(x.data + y.data, required_grad=needs_grad)

    def backward(ctx, grad):
        """d(x+y)/dx = d(x+y)/dy = 1: both inputs receive ``grad``."""
        return (grad, grad)
class SubFunction(Function):
    """Differentiable ``x - y``."""

    def forward(ctx, x, y):
        """Return a new node holding ``x.data - y.data``."""
        needs_grad = x.required_grad or y.required_grad
        return Varible(x.data - y.data, required_grad=needs_grad)

    def backward(ctx, grad):
        """d(x-y)/dx = 1 and d(x-y)/dy = -1."""
        return (grad, -grad)
class MulFunction(Function):
    """Differentiable ``x * y``."""

    def forward(ctx, x, y):
        """Return a new node holding ``x.data * y.data``; stash both operands."""
        needs_grad = x.required_grad or y.required_grad
        product = x.data * y.data
        # backward needs the raw operand values, not the nodes.
        ctx.saved_for_backward(x.data, y.data)
        return Varible(product, required_grad=needs_grad)

    def backward(ctx, grad):
        """d(xy)/dx = y and d(xy)/dy = x, each scaled by ``grad``."""
        lhs, rhs = ctx.get_saved_buffer()
        return (grad * rhs, grad * lhs)
class DivFunction(Function):
    """Differentiable ``x / y``."""

    def forward(ctx, x, y):
        """Return a new node holding ``x.data / y.data``; stash both operands."""
        needs_grad = x.required_grad or y.required_grad
        quotient = x.data / y.data
        # backward needs the raw operand values, not the nodes.
        ctx.saved_for_backward(x.data, y.data)
        return Varible(quotient, required_grad=needs_grad)

    def backward(ctx, grad):
        """d(x/y)/dx = 1/y and d(x/y)/dy = -x / y**2, each scaled by ``grad``."""
        numerator, denominator = ctx.get_saved_buffer()
        d_numerator = grad / denominator
        d_denominator = -grad * numerator * (denominator ** -2)
        return (d_numerator, d_denominator)
if __name__ == "__main__":
    # Demo: y = 2a/b + 3a^2 at a=1, b=2.
    # Expected gradients: dy/da = 2/b + 6a = 7, dy/db = -2a/b^2 = -0.5.
    a = Varible(1.0, required_grad=True)
    b = Varible(2.0, required_grad=True)
    ratio_term = 2 * a / b
    square_term = 3 * a * a
    y = ratio_term + square_term
    y.backward()
    for node in (y, b, a):
        print(node)