diff --git a/arden_types.py b/arden_types.py index cee8992..132afbe 100644 --- a/arden_types.py +++ b/arden_types.py @@ -37,6 +37,13 @@ class BaseType(ABC): def TIME(self): return self.time + def WHERE(self,other): + if isinstance(other, ListType): + return ListType([self for b_item in other.value if b_item.value==True]) + if isinstance(other, BoolType) and other: + return self + return NullType() + class NumType(BaseType): def __init__(self, value: (int, float)): @@ -77,31 +84,31 @@ class NumType(BaseType): def __gt__(self, other): if isinstance(other, NumType): return BoolType(self.value > other.value) - return NullType + return NullType() def __lt__(self, other): if isinstance(other, NumType): return BoolType(self.value < other.value) - return NullType + return NullType() def __ge__(self, other): if isinstance(other, NumType): return BoolType(self.value >= other.value) if isinstance(other, ListType): return ListType([self >= item for item in other.value]) - return NullType + return NullType() def __le__(self, other): if isinstance(other, NumType): return BoolType(self.value <= other.value) if isinstance(other, ListType): return ListType([self <= item for item in other.value]) - return NullType + return NullType() def __pow__(self, other): if isinstance(other, NumType): return NumType(self.value ** other.value) - return NullType + return NullType() def __neg__(self): return NumType(-self.value) @@ -156,7 +163,7 @@ class BoolType(BaseType): def __and__(self,other): if isinstance(other, BoolType): return BoolType(self.value and other.value) - return NullType + return NullType() def NOT(self): return ~self @@ -174,52 +181,54 @@ class ListType(BaseType): def __gt__(self, other): if isinstance(other, NumType): return ListType([item > other for item in self.value]) - return NullType + return NullType() def __lt__(self, other): if isinstance(other, NumType): return ListType([item < other for item in self.value]) - return NullType + return NullType() def __ge__(self, other): if isinstance(other, NumType): return ListType([item >= other for item in self.value]) if isinstance(other, ListType) and len(self.value) == len(other.value): return ListType([a_item >= b_item for a_item, b_item in zip(self.value, other.value)]) - return NullType + return NullType() def __le__(self, other): if isinstance(other, NumType): return ListType([item <= other for item in self.value]) if isinstance(other, ListType) and len(self.value) == len(other.value): return ListType([a_item <= b_item for a_item, b_item in zip(self.value, other.value)]) - return NullType + return NullType() def __and__(self,other): if isinstance(other, BoolType) or isinstance(other, NumType): return ListType([item and other for item in self.value]) if isinstance(other, ListType) and len(self.value) == len(other.value): return ListType([a_item and b_item for a_item, b_item in zip(self.value, other.value)]) - return NullType + return NullType() def __or__(self,other): if isinstance(other, BoolType) or isinstance(other, NumType): return ListType([item or other for item in self.value]) if isinstance(other, ListType) and len(self.value) == len(other.value): return ListType([a_item or b_item for a_item, b_item in zip(self.value, other.value)]) - return NullType + return NullType() def __truediv__(self,other): if isinstance(other, NumType): return ListType([item / other for item in self.value]) if isinstance(other, ListType) and len(self.value) == len(other.value): return ListType([a_item / b_item for a_item, b_item in zip(self.value, other.value)]) - return NullType + return NullType() def WHERE(self,other): if isinstance(other, ListType) and len(self.value) == len(other.value): - return ListType([a_item for a_item, b_item in zip(self.value, other.value) if b_item]) - return NullType + return ListType([a_item for a_item, b_item in zip(self.value, other.value) if b_item.value==True]) + if isinstance(other, BoolType) and other: + return self + return NullType() def IS(self,a,invert=False): if a is NumType: @@ -262,12 +271,12 @@ class DateType(BaseType): def __gt__(self, other): if isinstance(other, DateType): return BoolType(self.value > other.value) - return NullType + return NullType() def __lt__(self, other): if isinstance(other, DateType): return BoolType(self.value < other.value) - return NullType + return NullType() def __str__(self): return self.value.strftime('%Y-%m-%dT%H:%M:%S')