Skip to content

Commit

Permalink
Add arithmetic supported
Browse files Browse the repository at this point in the history
  • Loading branch information
XIELongDragon committed Feb 23, 2022
1 parent e40b45c commit 30ad302
Show file tree
Hide file tree
Showing 16 changed files with 261 additions and 0 deletions.
11 changes: 11 additions & 0 deletions dialect/mysql/mysql_dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,17 @@ func (mds *mysqlDialectSuite) TestBooleanOperations() {
)
}

func (mds *mysqlDialectSuite) TestArithmeticOperations() {
col := goqu.C("a")
ds := mds.GetDs("test")
mds.assertSQL(
sqlTestCase{ds: ds.Where(col.Add(1).Eq(2)), sql: "SELECT * FROM `test` WHERE ((`a` + 1) = 2)"},
sqlTestCase{ds: ds.Where(col.Sub(1).Gte(3)), sql: "SELECT * FROM `test` WHERE ((`a` - 1) >= 3)"},
sqlTestCase{ds: ds.Where(col.Mul(1)), sql: "SELECT * FROM `test` WHERE (`a` * 1)"},
sqlTestCase{ds: ds.Where(col.Div(1)), sql: "SELECT * FROM `test` WHERE (`a` / 1)"},
)
}

func (mds *mysqlDialectSuite) TestBitwiseOperations() {
col := goqu.C("a")
ds := mds.GetDs("test")
Expand Down
83 changes: 83 additions & 0 deletions exp/arithmetic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package exp

type arithmetic struct {
lhs Expression
rhs interface{}
op ArithmeticOperation
}

func NewArithmeticExpression(op ArithmeticOperation, lhs Expression, rhs interface{}) ArithmeticExpression {
return arithmetic{op: op, lhs: lhs, rhs: rhs}
}

func (a arithmetic) Clone() Expression {
return NewArithmeticExpression(a.op, a.lhs.Clone(), a.rhs)
}

func (a arithmetic) RHS() interface{} {
return a.rhs
}

func (a arithmetic) LHS() Expression {
return a.lhs
}

func (a arithmetic) Op() ArithmeticOperation {
return a.op
}

func (a arithmetic) Expression() Expression { return a }
func (a arithmetic) As(val interface{}) AliasedExpression { return NewAliasExpression(a, val) }
func (a arithmetic) Eq(val interface{}) BooleanExpression { return eq(a, val) }
func (a arithmetic) Neq(val interface{}) BooleanExpression { return neq(a, val) }
func (a arithmetic) Gt(val interface{}) BooleanExpression { return gt(a, val) }
func (a arithmetic) Gte(val interface{}) BooleanExpression { return gte(a, val) }
func (a arithmetic) Lt(val interface{}) BooleanExpression { return lt(a, val) }
func (a arithmetic) Lte(val interface{}) BooleanExpression { return lte(a, val) }
func (a arithmetic) Asc() OrderedExpression { return asc(a) }
func (a arithmetic) Desc() OrderedExpression { return desc(a) }
func (a arithmetic) Like(i interface{}) BooleanExpression { return like(a, i) }
func (a arithmetic) NotLike(i interface{}) BooleanExpression { return notLike(a, i) }
func (a arithmetic) ILike(i interface{}) BooleanExpression { return iLike(a, i) }
func (a arithmetic) NotILike(i interface{}) BooleanExpression { return notILike(a, i) }
func (a arithmetic) RegexpLike(val interface{}) BooleanExpression { return regexpLike(a, val) }
func (a arithmetic) RegexpNotLike(val interface{}) BooleanExpression { return regexpNotLike(a, val) }
func (a arithmetic) RegexpILike(val interface{}) BooleanExpression { return regexpILike(a, val) }
func (a arithmetic) RegexpNotILike(val interface{}) BooleanExpression { return regexpNotILike(a, val) }
func (a arithmetic) In(i ...interface{}) BooleanExpression { return in(a, i...) }
func (a arithmetic) NotIn(i ...interface{}) BooleanExpression { return notIn(a, i...) }
func (a arithmetic) Is(i interface{}) BooleanExpression { return is(a, i) }
func (a arithmetic) IsNot(i interface{}) BooleanExpression { return isNot(a, i) }
func (a arithmetic) IsNull() BooleanExpression { return is(a, nil) }
func (a arithmetic) IsNotNull() BooleanExpression { return isNot(a, nil) }
func (a arithmetic) IsTrue() BooleanExpression { return is(a, true) }
func (a arithmetic) IsNotTrue() BooleanExpression { return isNot(a, true) }
func (a arithmetic) IsFalse() BooleanExpression { return is(a, false) }
func (a arithmetic) IsNotFalse() BooleanExpression { return isNot(a, false) }
func (a arithmetic) Distinct() SQLFunctionExpression { return NewSQLFunctionExpression("DISTINCT", a) }
func (a arithmetic) Between(val RangeVal) RangeExpression { return between(a, val) }
func (a arithmetic) NotBetween(val RangeVal) RangeExpression { return notBetween(a, val) }
func (a arithmetic) Cast(t string) CastExpression { return NewCastExpression(a, t) }
func (a arithmetic) BitwiseInversion() BitwiseExpression { return bitwiseInversion(a) }
func (a arithmetic) BitwiseOr(val interface{}) BitwiseExpression { return bitwiseOr(a, val) }
func (a arithmetic) BitwiseAnd(val interface{}) BitwiseExpression { return bitwiseAnd(a, val) }
func (a arithmetic) BitwiseXor(val interface{}) BitwiseExpression { return bitwiseXor(a, val) }
func (a arithmetic) BitwiseLeftShift(val interface{}) BitwiseExpression {
return bitwiseLeftShift(a, val)
}
func (a arithmetic) BitwiseRightShift(val interface{}) BitwiseExpression {
return bitwiseRightShift(a, val)
}

func arithmeticAdd(lhs Expression, rhs interface{}) ArithmeticExpression {
return NewArithmeticExpression(ArithmeticAddOp, lhs, rhs)
}
func arithmeticSub(lhs Expression, rhs interface{}) ArithmeticExpression {
return NewArithmeticExpression(ArithmeticSubOp, lhs, rhs)
}
func arithmeticMul(lhs Expression, rhs interface{}) ArithmeticExpression {
return NewArithmeticExpression(ArithmeticMulOp, lhs, rhs)
}
func arithmeticDiv(lhs Expression, rhs interface{}) ArithmeticExpression {
return NewArithmeticExpression(ArithmeticDivOp, lhs, rhs)
}
4 changes: 4 additions & 0 deletions exp/bitwise.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ func (b bitwise) IsNotFalse() BooleanExpression { return isNo
func (b bitwise) Distinct() SQLFunctionExpression { return NewSQLFunctionExpression("DISTINCT", b) }
func (b bitwise) Between(val RangeVal) RangeExpression { return between(b, val) }
func (b bitwise) NotBetween(val RangeVal) RangeExpression { return notBetween(b, val) }
func (b bitwise) Add(val interface{}) ArithmeticExpression { return arithmeticAdd(b, val) }
func (b bitwise) Sub(val interface{}) ArithmeticExpression { return arithmeticSub(b, val) }
func (b bitwise) Mul(val interface{}) ArithmeticExpression { return arithmeticMul(b, val) }
func (b bitwise) Div(val interface{}) ArithmeticExpression { return arithmeticDiv(b, val) }

// used internally to create a Bitwise Inversion BitwiseExpression
func bitwiseInversion(rhs Expression) BitwiseExpression {
Expand Down
4 changes: 4 additions & 0 deletions exp/bitwise_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ func (bes *bitwiseExpressionSuite) TestAllOthers() {
{Ex: be.IsFalse(), Expected: exp.NewBooleanExpression(exp.IsOp, be, false)},
{Ex: be.IsNotFalse(), Expected: exp.NewBooleanExpression(exp.IsNotOp, be, false)},
{Ex: be.Distinct(), Expected: exp.NewSQLFunctionExpression("DISTINCT", be)},
{Ex: be.Add(1), Expected: exp.NewArithmeticExpression(exp.ArithmeticAddOp, be, 1)},
{Ex: be.Sub(1), Expected: exp.NewArithmeticExpression(exp.ArithmeticSubOp, be, 1)},
{Ex: be.Mul(1), Expected: exp.NewArithmeticExpression(exp.ArithmeticMulOp, be, 1)},
{Ex: be.Div(1), Expected: exp.NewArithmeticExpression(exp.ArithmeticDivOp, be, 1)},
}

for _, tc := range testCases {
Expand Down
4 changes: 4 additions & 0 deletions exp/cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,7 @@ func (c cast) IsNotFalse() BooleanExpression { return isNot(c
func (c cast) Distinct() SQLFunctionExpression { return NewSQLFunctionExpression("DISTINCT", c) }
func (c cast) Between(val RangeVal) RangeExpression { return between(c, val) }
func (c cast) NotBetween(val RangeVal) RangeExpression { return notBetween(c, val) }
func (c cast) Add(val interface{}) ArithmeticExpression { return arithmeticAdd(c, val) }
func (c cast) Sub(val interface{}) ArithmeticExpression { return arithmeticSub(c, val) }
func (c cast) Mul(val interface{}) ArithmeticExpression { return arithmeticMul(c, val) }
func (c cast) Div(val interface{}) ArithmeticExpression { return arithmeticDiv(c, val) }
4 changes: 4 additions & 0 deletions exp/cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ func (ces *castExpressionSuite) TestAllOthers() {
{Ex: ce.IsFalse(), Expected: exp.NewBooleanExpression(exp.IsOp, ce, false)},
{Ex: ce.IsNotFalse(), Expected: exp.NewBooleanExpression(exp.IsNotOp, ce, false)},
{Ex: ce.Distinct(), Expected: exp.NewSQLFunctionExpression("DISTINCT", ce)},
{Ex: ce.Add(1), Expected: exp.NewArithmeticExpression(exp.ArithmeticAddOp, ce, 1)},
{Ex: ce.Sub(1), Expected: exp.NewArithmeticExpression(exp.ArithmeticSubOp, ce, 1)},
{Ex: ce.Mul(1), Expected: exp.NewArithmeticExpression(exp.ArithmeticMulOp, ce, 1)},
{Ex: ce.Div(1), Expected: exp.NewArithmeticExpression(exp.ArithmeticDivOp, ce, 1)},
}

for _, tc := range testCases {
Expand Down
52 changes: 52 additions & 0 deletions exp/exp.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,13 @@ type (
// I("col").BitRighttShift(1) // ("col" >> 1)
BitwiseRightShift(interface{}) BitwiseExpression
}

Arithmeticable interface {
Add(interface{}) ArithmeticExpression
Sub(interface{}) ArithmeticExpression
Mul(interface{}) ArithmeticExpression
Div(interface{}) ArithmeticExpression
}
)

type (
Expand Down Expand Up @@ -223,6 +230,7 @@ type (
Expression
Aliaseable
Comparable
Arithmeticable
Isable
Inable
Likeable
Expand All @@ -242,6 +250,7 @@ type (
Expression
Aliaseable
Comparable
Arithmeticable
Inable
Isable
Likeable
Expand Down Expand Up @@ -310,6 +319,7 @@ type (
Expression
Aliaseable
Comparable
Arithmeticable
Inable
Isable
Likeable
Expand Down Expand Up @@ -383,6 +393,7 @@ type (
Expression
Aliaseable
Comparable
Arithmeticable
Isable
Inable
Likeable
Expand All @@ -395,6 +406,27 @@ type (
Args() []interface{}
}

ArithmeticOperation int
ArithmeticExpression interface {
Expression
Aliaseable
Comparable
Inable
Isable
Likeable
Rangeable
Orderable
Distinctable
Castable
Bitwiseable
// Returns the operator for the expression
Op() ArithmeticOperation
// The left hand side of the expression (e.g. I("a")
LHS() Expression
// The right hand side of the expression could be a primitive value, dataset, or expression
RHS() interface{}
}

NullSortType int
SortDirection int
// An expression for specifying sort order and options
Expand Down Expand Up @@ -438,6 +470,7 @@ type (
Aliaseable
Rangeable
Comparable
Arithmeticable
Orderable
Isable
Inable
Expand Down Expand Up @@ -598,6 +631,11 @@ const (
BitwiseXorOp
BitwiseLeftShiftOp
BitwiseRightShiftOp

ArithmeticAddOp ArithmeticOperation = iota
ArithmeticSubOp
ArithmeticMulOp
ArithmeticDivOp
)

var (
Expand Down Expand Up @@ -693,6 +731,20 @@ func (bi BitwiseOperation) String() string {
return fmt.Sprintf("%d", bi)
}

func (ao ArithmeticOperation) String() string {
switch ao {
case ArithmeticAddOp:
return "Addition"
case ArithmeticSubOp:
return "Subtraction"
case ArithmeticMulOp:
return "Multiplication"
case ArithmeticDivOp:
return "Division"
}
return fmt.Sprintf("%d", ao)
}

func (ro RangeOperation) String() string {
switch ro {
case BetweenOp:
Expand Down
13 changes: 13 additions & 0 deletions exp/func.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,16 @@ func (sfe sqlFunctionExpression) OverName(windowName IdentifierExpression) SQLWi

func (sfe sqlFunctionExpression) Asc() OrderedExpression { return asc(sfe) }
func (sfe sqlFunctionExpression) Desc() OrderedExpression { return desc(sfe) }

func (sfe sqlFunctionExpression) Add(val interface{}) ArithmeticExpression {
return arithmeticAdd(sfe, val)
}
func (sfe sqlFunctionExpression) Sub(val interface{}) ArithmeticExpression {
return arithmeticSub(sfe, val)
}
func (sfe sqlFunctionExpression) Mul(val interface{}) ArithmeticExpression {
return arithmeticMul(sfe, val)
}
func (sfe sqlFunctionExpression) Div(val interface{}) ArithmeticExpression {
return arithmeticDiv(sfe, val)
}
4 changes: 4 additions & 0 deletions exp/func_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ func (sfes *sqlFunctionExpressionSuite) TestAllOthers() {
{Ex: fn.IsNotFalse(), Expected: exp.NewBooleanExpression(exp.IsNotOp, fn, false)},
{Ex: fn.Desc(), Expected: exp.NewOrderedExpression(fn, exp.DescSortDir, exp.NoNullsSortType)},
{Ex: fn.Asc(), Expected: exp.NewOrderedExpression(fn, exp.AscDir, exp.NoNullsSortType)},
{Ex: fn.Add(1), Expected: exp.NewArithmeticExpression(exp.ArithmeticAddOp, fn, 1)},
{Ex: fn.Sub(1), Expected: exp.NewArithmeticExpression(exp.ArithmeticSubOp, fn, 1)},
{Ex: fn.Mul(1), Expected: exp.NewArithmeticExpression(exp.ArithmeticMulOp, fn, 1)},
{Ex: fn.Div(1), Expected: exp.NewArithmeticExpression(exp.ArithmeticDivOp, fn, 1)},
}

for _, tc := range testCases {
Expand Down
5 changes: 5 additions & 0 deletions exp/ident.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,8 @@ func (i identifier) Between(val RangeVal) RangeExpression { return between(i, va

// Returns a RangeExpression for checking that a identifier is between two values (e.g "my_col" BETWEEN 1 AND 10)
func (i identifier) NotBetween(val RangeVal) RangeExpression { return notBetween(i, val) }

func (i identifier) Add(val interface{}) ArithmeticExpression { return arithmeticAdd(i, val) }
func (i identifier) Sub(val interface{}) ArithmeticExpression { return arithmeticSub(i, val) }
func (i identifier) Mul(val interface{}) ArithmeticExpression { return arithmeticMul(i, val) }
func (i identifier) Div(val interface{}) ArithmeticExpression { return arithmeticDiv(i, val) }
4 changes: 4 additions & 0 deletions exp/ident_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,10 @@ func (ies *identifierExpressionSuite) TestAllOthers() {
{Ex: ident.BitwiseXor(bitwiseVals), Expected: exp.NewBitwiseExpression(exp.BitwiseXorOp, ident, bitwiseVals)},
{Ex: ident.BitwiseLeftShift(bitwiseVals), Expected: exp.NewBitwiseExpression(exp.BitwiseLeftShiftOp, ident, bitwiseVals)},
{Ex: ident.BitwiseRightShift(bitwiseVals), Expected: exp.NewBitwiseExpression(exp.BitwiseRightShiftOp, ident, bitwiseVals)},
{Ex: ident.Add(1), Expected: exp.NewArithmeticExpression(exp.ArithmeticAddOp, ident, 1)},
{Ex: ident.Sub(1), Expected: exp.NewArithmeticExpression(exp.ArithmeticSubOp, ident, 1)},
{Ex: ident.Mul(1), Expected: exp.NewArithmeticExpression(exp.ArithmeticMulOp, ident, 1)},
{Ex: ident.Div(1), Expected: exp.NewArithmeticExpression(exp.ArithmeticDivOp, ident, 1)},
}

for _, tc := range testCases {
Expand Down
4 changes: 4 additions & 0 deletions exp/literal.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,7 @@ func (l literal) BitwiseLeftShift(val interface{}) BitwiseExpression { return bi
func (l literal) BitwiseRightShift(val interface{}) BitwiseExpression {
return bitwiseRightShift(l, val)
}
func (l literal) Add(val interface{}) ArithmeticExpression { return arithmeticAdd(l, val) }
func (l literal) Sub(val interface{}) ArithmeticExpression { return arithmeticSub(l, val) }
func (l literal) Mul(val interface{}) ArithmeticExpression { return arithmeticMul(l, val) }
func (l literal) Div(val interface{}) ArithmeticExpression { return arithmeticDiv(l, val) }
4 changes: 4 additions & 0 deletions exp/literal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ func (les *literalExpressionSuite) TestAllOthers() {
{Ex: le.BitwiseXor(bitwiseVals), Expected: exp.NewBitwiseExpression(exp.BitwiseXorOp, le, bitwiseVals)},
{Ex: le.BitwiseLeftShift(bitwiseVals), Expected: exp.NewBitwiseExpression(exp.BitwiseLeftShiftOp, le, bitwiseVals)},
{Ex: le.BitwiseRightShift(bitwiseVals), Expected: exp.NewBitwiseExpression(exp.BitwiseRightShiftOp, le, bitwiseVals)},
{Ex: le.Add(1), Expected: exp.NewArithmeticExpression(exp.ArithmeticAddOp, le, 1)},
{Ex: le.Sub(1), Expected: exp.NewArithmeticExpression(exp.ArithmeticSubOp, le, 1)},
{Ex: le.Mul(1), Expected: exp.NewArithmeticExpression(exp.ArithmeticMulOp, le, 1)},
{Ex: le.Div(1), Expected: exp.NewArithmeticExpression(exp.ArithmeticDivOp, le, 1)},
}

for _, tc := range testCases {
Expand Down
23 changes: 23 additions & 0 deletions sqlgen/expression_sql_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ func errUnsupportedBitwiseExpressionOperator(op exp.BitwiseOperation) error {
return errors.New("bitwise operator '%+v' not supported", op)
}

func errUnsupportedArithmeticExpressionOperator(op exp.ArithmeticOperation) error {
return errors.New("arithmetic operator '%+v' not supported", op)
}

func errUnsupportedRangeExpressionOperator(op exp.RangeOperation) error {
return errors.New("range operator %+v not supported", op)
}
Expand Down Expand Up @@ -176,6 +180,8 @@ func (esg *expressionSQLGenerator) expressionSQL(b sb.SQLBuilder, expression exp
esg.booleanExpressionSQL(b, e)
case exp.BitwiseExpression:
esg.bitwiseExpressionSQL(b, e)
case exp.ArithmeticExpression:
esg.arithmeticExpressionSQL(b, e)
case exp.RangeExpression:
esg.rangeExpressionSQL(b, e)
case exp.OrderedExpression:
Expand Down Expand Up @@ -448,6 +454,23 @@ func (esg *expressionSQLGenerator) bitwiseExpressionSQL(b sb.SQLBuilder, operato
b.WriteRunes(esg.dialectOptions.RightParenRune)
}

// Generates SQL for a ArithmeticExpresion (e.g. I("a").Add(1) -> "a" + 1)
func (esg *expressionSQLGenerator) arithmeticExpressionSQL(b sb.SQLBuilder, operator exp.ArithmeticExpression) {
b.WriteRunes(esg.dialectOptions.LeftParenRune)
esg.Generate(b, operator.LHS())
b.WriteRunes(esg.dialectOptions.SpaceRune)
operatorOp := operator.Op()
if val, ok := esg.dialectOptions.ArithmeticOperatorLookup[operatorOp]; ok {
b.Write(val)
} else {
b.SetError(errUnsupportedArithmeticExpressionOperator(operatorOp))
return
}
b.WriteRunes(esg.dialectOptions.SpaceRune)
esg.Generate(b, operator.RHS())
b.WriteRunes(esg.dialectOptions.RightParenRune)
}

// Generates SQL for a RangeExpresion (e.g. I("a").Between(RangeVal{Start:2,End:5}) -> "a" BETWEEN 2 AND 5)
func (esg *expressionSQLGenerator) rangeExpressionSQL(b sb.SQLBuilder, operator exp.RangeExpression) {
b.WriteRunes(esg.dialectOptions.LeftParenRune)
Expand Down
Loading

0 comments on commit 30ad302

Please sign in to comment.