Indicators: Add support for correlation coefficient (#519)

* Add support for correlation coefficient code

* Bump ta depends

* Bump gct-ta depends version
This commit is contained in:
Adrian Gallagher
2020-06-23 17:49:08 +10:00
committed by GitHub
parent 59edb47960
commit 9f775ca952
8 changed files with 162 additions and 59 deletions

View File

@@ -0,0 +1,15 @@
fmt := import("fmt")
exch := import("exchange")
t := import("times")
cc := import("indicator/correlationcoefficient")
load := func() {
start := t.date(2017, 8 , 17 , 0 , 0 , 0, 0)
end := t.add_date(start, 0, 6 , 0)
ohlcvDataBTC := exch.ohlcv("binance", "BTC-USDT", "-", "SPOT", start, end, "1d")
ohlcvDataETH := exch.ohlcv("binance", "ETH-USDT", "-", "SPOT", start, end, "1d")
ret := cc.calculate(ohlcvDataBTC.candles, ohlcvDataETH.candles, 20)
fmt.println(ret)
}
load()

View File

@@ -53,6 +53,8 @@ func WriteAsCSV(args ...objects.Object) (objects.Object, error) {
temp, err = convertRSI(args[i])
case indicators.SimpleMovingAverage:
temp, err = convertSMA(args[i])
case indicators.CorrelationCoefficient:
temp, err = convertCorrelationCoefficient(args[i])
case indicators.OHLCV:
temp, err = convertOHLCV(args[i])
front = true
@@ -387,6 +389,32 @@ func convertSMA(a objects.Object) ([][]string, error) {
return bucket, nil
}
func convertCorrelationCoefficient(a objects.Object) ([][]string, error) {
obj, ok := objects.ToInterface(a).(*indicators.Correlation)
if !ok {
return nil, errors.New("casting failure")
}
var bucket = [][]string{
{
indicators.CorrelationCoefficient,
},
{
fmt.Sprintf("Period:%d", obj.Period),
},
}
var val string
for i := range obj.Value {
val, ok = objects.ToString(obj.Value[i])
if !ok {
return nil, errors.New("cannot convert object to string")
}
bucket = append(bucket, []string{val})
}
return bucket, nil
}
func convertOHLCV(a objects.Object) ([][]string, error) {
obj, ok := objects.ToInterface(a).(*OHLCV)
if !ok {

View File

@@ -12,16 +12,17 @@ import (
)
var (
atrPayload = &indicators.ATR{Array: oneElement}
bbandsPayload = &indicators.BBands{Array: threeElement}
emaPayload = &indicators.EMA{Array: oneElement}
macdPayload = &indicators.MACD{Array: threeElement}
mfiPayload = &indicators.MFI{Array: oneElement}
obvPayload = &indicators.OBV{Array: oneElement}
rsiPayload = &indicators.RSI{Array: oneElement}
smaPayload = &indicators.SMA{Array: oneElement}
ohlcPayload = &OHLCV{Map: ohlcdata}
unhandled = &objects.Array{}
atrPayload = &indicators.ATR{Array: oneElement}
bbandsPayload = &indicators.BBands{Array: threeElement}
emaPayload = &indicators.EMA{Array: oneElement}
macdPayload = &indicators.MACD{Array: threeElement}
mfiPayload = &indicators.MFI{Array: oneElement}
obvPayload = &indicators.OBV{Array: oneElement}
rsiPayload = &indicators.RSI{Array: oneElement}
smaPayload = &indicators.SMA{Array: oneElement}
correlationPayload = &indicators.Correlation{Array: oneElement}
ohlcPayload = &OHLCV{Map: ohlcdata}
unhandled = &objects.Array{}
oneElement = objects.Array{
Value: []objects.Object{
@@ -157,6 +158,7 @@ func TestCommonWriteToCSV(t *testing.T) {
obvPayload,
rsiPayload,
smaPayload,
correlationPayload,
ohlcPayload)
if err != nil {
t.Fatal(err)

View File

@@ -0,0 +1,91 @@
package indicators
import (
"errors"
"fmt"
"math"
"strings"
objects "github.com/d5/tengo/v2"
"github.com/thrasher-corp/gct-ta/indicators"
"github.com/thrasher-corp/gocryptotrader/gctscript/modules"
"github.com/thrasher-corp/gocryptotrader/gctscript/wrappers/validator"
)
// CorrelationCoefficientModule indicator commands
var CorrelationCoefficientModule = map[string]objects.Object{
"calculate": &objects.UserFunction{Name: "calculate", Value: correlationCoefficient},
}
// CorrelationCoefficient is the string constant
const CorrelationCoefficient = "Correlation Coefficient"
// Correlation defines a custom correlation coefficient indicator tengo object
type Correlation struct {
objects.Array
Period int
}
// TypeName returns the name of the custom type.
func (c *Correlation) TypeName() string {
return CorrelationCoefficient
}
func correlationCoefficient(args ...objects.Object) (objects.Object, error) {
if len(args) != 3 {
return nil, objects.ErrWrongNumArguments
}
r := new(Correlation)
if validator.IsTestExecution.Load() == true {
return r, nil
}
var allErrors []string
ohlcvProcessor := func(args []objects.Object, idx int) ([]float64, error) {
ohlcvInput := objects.ToInterface(args[idx])
ohlcvInputData, valid := ohlcvInput.([]interface{})
if !valid {
return nil, fmt.Errorf(modules.ErrParameterConvertFailed, OHLCV)
}
var ohlcvClose []float64
for x := range ohlcvInputData {
t := ohlcvInputData[x].([]interface{})
value, err := toFloat64(t[4])
if err != nil {
allErrors = append(allErrors, err.Error())
}
ohlcvClose = append(ohlcvClose, value)
}
return ohlcvClose, nil
}
closures1, err := ohlcvProcessor(args, 0)
if err != nil {
return nil, err
}
closures2, err := ohlcvProcessor(args, 1)
if err != nil {
return nil, err
}
inTimePeriod, ok := objects.ToInt(args[2])
if !ok {
allErrors = append(allErrors, fmt.Sprintf(modules.ErrParameterConvertFailed, inTimePeriod))
}
if len(allErrors) > 0 {
return nil, errors.New(strings.Join(allErrors, ", "))
}
r.Period = inTimePeriod
ret := indicators.CorrelationCoefficient(closures1, closures2, inTimePeriod)
for x := range ret {
r.Value = append(r.Value, &objects.Float{Value: math.Round(ret[x]*100) / 100})
}
return r, nil
}

View File

@@ -11,7 +11,7 @@ func TestGetModuleMap(t *testing.T) {
if xType != reflect.Slice {
t.Fatalf("AllModuleNames() should return slice instead received: %v", x)
}
if len(x) != 8 {
t.Fatalf("unexpected results received expected 7 received: %v", len(x))
if len(x) != 9 {
t.Fatalf("unexpected results received expected 9 received: %v", len(x))
}
}

View File

@@ -7,12 +7,13 @@ import (
// Modules map of all loadable modules
var Modules = map[string]map[string]tengo.Object{
"indicator/bbands": indicators.BBandsModule,
"indicator/macd": indicators.MACDModule,
"indicator/ema": indicators.EMAModule,
"indicator/sma": indicators.SMAModule,
"indicator/rsi": indicators.RsiModule,
"indicator/obv": indicators.ObvModule,
"indicator/mfi": indicators.MfiModule,
"indicator/atr": indicators.AtrModule,
"indicator/bbands": indicators.BBandsModule,
"indicator/macd": indicators.MACDModule,
"indicator/ema": indicators.EMAModule,
"indicator/sma": indicators.SMAModule,
"indicator/rsi": indicators.RsiModule,
"indicator/obv": indicators.ObvModule,
"indicator/mfi": indicators.MfiModule,
"indicator/atr": indicators.AtrModule,
"indicator/correlationcoefficient": indicators.CorrelationCoefficientModule,
}