mirror of
https://github.com/d0zingcat/gocryptotrader.git
synced 2026-05-13 15:09:42 +00:00
Database interface & auditing feature (#332)
* added audit manager * Basic database DOA setup * Added base config file * added sqlite support and creation of schema * added basic tests and config entry * corrected issues of database is disabled * fixed path for test * WIP * Added tests fixed config checking * reverted files back to upstream * reverted go.mod files * no more test test test * removed local testing details for psql * hello * added comments * increased ping to 30 seconds * renamed database table and added additional condition around test * removed database test details * goimport ran on all files * WIP * first attempt at migration * fixes for migration system * Migration system logger interface implemented * fixes to print functions * added write pooling pass * gofmt :D * formatted imports correctly * removed old code * added creation of migration * gofmt * :D Hello * ❌ 🏎️ * maybe one day i will remember to revert go mod files * checked err return condition correctly * first changes for PR feedback * code clean up * protect Connected with RWmutex & event with mutex * : D * we can just pretend like it never happened * MOved migrations back to source directory and added README * readme formatting update * Addd command line override for datadir * use correct var when creating a migration and confirm folder is created * Check if database version is newer than latest migration and also you know make migrations work..... * uses filepath instead of manual path to use correct path seperator * Add connection message and lower timeout * Added support for sslmode for psql * no longer force Close of database instead allow driver to maage * Added closer func to test output * sslmode added to example config
This commit is contained in:
76
database/README.md
Normal file
76
database/README.md
Normal file
@@ -0,0 +1,76 @@
|
||||
# GoCryptoTrader package Database
|
||||
|
||||
<img src="https://github.com/thrasher-corp/gocryptotrader/blob/master/web/src/assets/page-logo.png?raw=true" width="350px" height="350px" hspace="70">
|
||||
|
||||
|
||||
[](https://travis-ci.org/thrasher-corp/gocryptotrader)
|
||||
[](https://github.com/thrasher-corp/gocryptotrader/blob/master/LICENSE)
|
||||
[](https://godoc.org/github.com/thrasher-corp/gocryptotrader/portfolio)
|
||||
[](http://codecov.io/github/thrasher-corp/gocryptotrader?branch=master)
|
||||
[](https://goreportcard.com/report/github.com/thrasher-corp/gocryptotrader)
|
||||
|
||||
|
||||
This database package is part of the GoCryptoTrader codebase.
|
||||
|
||||
## This is still in active development
|
||||
|
||||
You can track ideas, planned features and what's in progresss on this Trello board: [https://trello.com/b/ZAhMhpOy/gocryptotrader](https://trello.com/b/ZAhMhpOy/gocryptotrader).
|
||||
|
||||
Join our slack to discuss all things related to GoCryptoTrader! [GoCryptoTrader Slack](https://join.slack.com/t/gocryptotrader/shared_invite/enQtNTQ5NDAxMjA2Mjc5LTQyYjIxNGVhMWU5MDZlOGYzMmE0NTJmM2MzYWY5NGMzMmM4MzUwNTBjZTEzNjIwODM5NDcxODQwZDljMGQyNGY)
|
||||
|
||||
## Current Features for database package
|
||||
|
||||
+ Establishes & Maintains database connection across program life cycle
|
||||
+ Multiple database support via simple repository model
|
||||
+ Run migration on connection to assure database is at correct version
|
||||
|
||||
## How to use
|
||||
|
||||
##### To Manually migrate to the latest database you can run the "dbmigrate" helper in the cmd folder
|
||||
|
||||
This will parse and run all migration files in your $GoCryptoTrader/database/migrations
|
||||
|
||||
_This is also run from the bot when a connection is established to the database_
|
||||
|
||||
```sh
|
||||
go run ./cmd/dbmigrate
|
||||
```
|
||||
A Makefile command has also been added for this
|
||||
```sh
|
||||
make db_migrate
|
||||
```
|
||||
|
||||
##### To create a new migrate file you can also run the same command with the -create "migration name" flag
|
||||
|
||||
```sh
|
||||
go run ./cmd/dbmigrate -create "alter some table"
|
||||
```
|
||||
|
||||
##### Adding a new model
|
||||
|
||||
+ Create Model in github.com/thrasher-corp/gocryptotrader/database/models directory
|
||||
|
||||
##### Adding a Repository
|
||||
+ Create Repository directory in github.com/thrasher-corp/gocryptotrader/database/repository/
|
||||
+ Create a base Repository interface with any required Methods
|
||||
+ Create a per driver implementation of the Repository that implement all required methods to match the interface
|
||||
|
||||
## Contribution
|
||||
|
||||
Please feel free to submit any pull requests or suggest any desired features to be added.
|
||||
|
||||
When submitting a PR, please abide by our coding guidelines:
|
||||
|
||||
+ Code must adhere to the official Go [formatting](https://golang.org/doc/effective_go.html#formatting) guidelines (i.e. uses [gofmt](https://golang.org/cmd/gofmt/)).
|
||||
+ Code must be documented adhering to the official Go [commentary](https://golang.org/doc/effective_go.html#commentary) guidelines.
|
||||
+ Code must adhere to our [coding style](https://github.com/thrasher-corp/gocryptotrader/blob/master/doc/coding_style.md).
|
||||
+ Pull requests need to be based on and opened against the `master` branch.
|
||||
|
||||
## Donations
|
||||
|
||||
<img src="https://github.com/thrasher-corp/gocryptotrader/blob/master/web/src/assets/donate.png?raw=true" hspace="70">
|
||||
|
||||
If this framework helped you in any way, or you would like to support the developers working on it, please donate Bitcoin to:
|
||||
|
||||
***1F5zVDgNjorJ51oGebSvNCrSAHpwGkUdDB***
|
||||
|
||||
38
database/db_types.go
Normal file
38
database/db_types.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/thrasher-corp/gocryptotrader/database/drivers"
|
||||
)
|
||||
|
||||
// Database holds a pointer to sql connection, DataPath which is used for file based databases
|
||||
// and a pointer to a Config struct
|
||||
type Database struct {
|
||||
Config *Config
|
||||
DataPath string
|
||||
SQL *sqlx.DB
|
||||
|
||||
Connected bool
|
||||
Mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Config holds connection information about the database what the driver type is and if its enabled or not
|
||||
type Config struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Driver string `json:"driver"`
|
||||
drivers.ConnectionDetails `json:"connectionDetails"`
|
||||
}
|
||||
|
||||
// Conn is a global copy of Database{} struct
|
||||
var Conn = &Database{}
|
||||
|
||||
var (
|
||||
// ErrNoDatabaseProvided error to display when no database is provided
|
||||
ErrNoDatabaseProvided = errors.New("no database provided")
|
||||
|
||||
// SupportedDrivers slice of supported database driver types
|
||||
SupportedDrivers = []string{"sqlite", "postgres"}
|
||||
)
|
||||
11
database/drivers/drivers_type.go
Normal file
11
database/drivers/drivers_type.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package drivers
|
||||
|
||||
// ConnectionDetails holds DSN information
|
||||
type ConnectionDetails struct {
|
||||
Host string
|
||||
Port uint16
|
||||
Username string
|
||||
Password string
|
||||
Database string
|
||||
SSLMode string
|
||||
}
|
||||
41
database/drivers/postgres/postgresql.go
Normal file
41
database/drivers/postgres/postgresql.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/stdlib"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/thrasher-corp/gocryptotrader/database"
|
||||
)
|
||||
|
||||
// Connect establishes a connection pool to the database
|
||||
func Connect() (*database.Database, error) {
|
||||
configDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s database=%s sslmode=%s",
|
||||
database.Conn.Config.Host,
|
||||
database.Conn.Config.Port,
|
||||
database.Conn.Config.Username,
|
||||
database.Conn.Config.Password,
|
||||
database.Conn.Config.Database,
|
||||
database.Conn.Config.SSLMode)
|
||||
|
||||
connConfig, err := pgx.ParseDSN(configDSN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
connPool, err := pgx.NewConnPool(pgx.ConnPoolConfig{
|
||||
ConnConfig: connConfig,
|
||||
AfterConnect: nil,
|
||||
MaxConnections: 20,
|
||||
AcquireTimeout: 30 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sqlxDB := stdlib.OpenDBFromPool(connPool)
|
||||
database.Conn.SQL = sqlx.NewDb(sqlxDB, "pgx")
|
||||
return database.Conn, nil
|
||||
}
|
||||
28
database/drivers/sqlite/sqlite.go
Normal file
28
database/drivers/sqlite/sqlite.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
// import sqlite3 driver
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/thrasher-corp/gocryptotrader/database"
|
||||
)
|
||||
|
||||
// Connect creates a connection to the entered database
|
||||
// With SQLite the database is not created until first read/write
|
||||
|
||||
func Connect() (*database.Database, error) {
|
||||
if database.Conn.Config.Database == "" {
|
||||
return nil, database.ErrNoDatabaseProvided
|
||||
}
|
||||
|
||||
databaseFullLocation := filepath.Join(database.Conn.DataPath, database.Conn.Config.Database)
|
||||
dbConn, err := sqlx.Open("sqlite3", databaseFullLocation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
database.Conn.SQL = dbConn
|
||||
return database.Conn, nil
|
||||
}
|
||||
180
database/migration/migrate.go
Normal file
180
database/migration/migrate.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// LoadMigrations will load all migrations in the ./database/migration/migrations folder
|
||||
func (m *Migrator) LoadMigrations() error {
|
||||
flag.Visit(func(f *flag.Flag) {
|
||||
if f.Name == "migrationdir" {
|
||||
MigrationDir = flag.Lookup("migrationdir").Value.String()
|
||||
}
|
||||
})
|
||||
|
||||
m.Log.Printf("Using migration folder %s\n", MigrationDir)
|
||||
|
||||
migration, err := filepath.Glob(MigrationDir + "/*.sql")
|
||||
|
||||
if err != nil {
|
||||
return errors.New("failed to load migrations")
|
||||
}
|
||||
|
||||
if len(migration) == 0 {
|
||||
return errors.New("no migration files found")
|
||||
}
|
||||
|
||||
sort.Strings(migration)
|
||||
|
||||
for x := range migration {
|
||||
err = m.loadMigration(migration[x])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Migrator) loadMigration(migration string) error {
|
||||
file, err := os.Open(migration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fileData := strings.Trim(file.Name(), MigrationDir)
|
||||
fileSeq := strings.Split(fileData, "_")
|
||||
seq, _ := strconv.Atoi(fileSeq[0])
|
||||
|
||||
b, err := ioutil.ReadAll(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
up := bytes.Split(b, []byte("-- up"))
|
||||
|
||||
if len(up) == 1 {
|
||||
return fmt.Errorf("invalid migration file %v", file.Name())
|
||||
}
|
||||
|
||||
down := strings.Split(string(up[1]), "-- down")
|
||||
|
||||
temp := Migration{
|
||||
Sequence: seq,
|
||||
UpSQL: down[0],
|
||||
DownSQL: down[1],
|
||||
}
|
||||
|
||||
m.Migrations = append(m.Migrations, temp)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RunMigration attempts to run current migrations against a database
|
||||
func (m *Migrator) RunMigration() (err error) {
|
||||
v, err := m.getCurrentVersion()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
m.Log.Printf("Current database version: %v\n", v)
|
||||
|
||||
latestSeq := m.Migrations[len(m.Migrations)-1].Sequence
|
||||
|
||||
if v > latestSeq {
|
||||
return errors.New("current database version is greater than latest migration halting further migrations")
|
||||
}
|
||||
|
||||
if v == latestSeq {
|
||||
m.Log.Println("no migrations to be run")
|
||||
return
|
||||
}
|
||||
|
||||
tx, err := m.Conn.SQL.Begin()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for y := 0; y < len(m.Migrations); y++ {
|
||||
if m.Migrations[y].Sequence <= v {
|
||||
continue
|
||||
}
|
||||
|
||||
err = m.txBegin(tx, m.checkConvert(m.Migrations[y].UpSQL))
|
||||
if err != nil {
|
||||
return tx.Rollback()
|
||||
}
|
||||
|
||||
_, err = tx.Exec("update version set version=$1", m.Migrations[y].Sequence)
|
||||
if err != nil {
|
||||
return tx.Rollback()
|
||||
}
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return tx.Rollback()
|
||||
}
|
||||
|
||||
m.Log.Println("Migration completed")
|
||||
m.Log.Printf("New database version: %v\n", latestSeq)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Migrator) txBegin(tx *sql.Tx, input string) error {
|
||||
_, err := tx.Exec(input)
|
||||
if err != nil {
|
||||
m.Log.Errorf("%v", err)
|
||||
return tx.Rollback()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Migrator) getCurrentVersion() (v int, err error) {
|
||||
err = m.checkVersionTableExists()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = m.Conn.SQL.QueryRow("select version from version").Scan(&v)
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Migrator) checkVersionTableExists() error {
|
||||
query := `
|
||||
CREATE TABLE IF NOT EXISTS version(
|
||||
version int not null
|
||||
);
|
||||
|
||||
INSERT INTO version SELECT 0 WHERE 0=(SELECT COUNT(*) from version);
|
||||
`
|
||||
|
||||
_, err := m.Conn.SQL.Exec(m.checkConvert(query))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Migrator) checkConvert(input string) string {
|
||||
if m.Conn.Config.Driver != "sqlite" {
|
||||
return input
|
||||
}
|
||||
|
||||
// Common PSQL -> SQLITE conversion
|
||||
// TODO: Find a better way to handle this list
|
||||
|
||||
r := strings.NewReplacer(
|
||||
"bigserial", "integer",
|
||||
"int", "integer",
|
||||
"now()", "CURRENT_TIMESTAMP")
|
||||
|
||||
return r.Replace(input)
|
||||
}
|
||||
37
database/migration/migrate_type.go
Normal file
37
database/migration/migrate_type.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/thrasher-corp/gocryptotrader/database"
|
||||
)
|
||||
|
||||
var (
|
||||
// MigrationDir Default folder to look for migrations to apply
|
||||
MigrationDir = filepath.Join("./database", "migration", "migrations")
|
||||
)
|
||||
|
||||
// Migration holds all information passes from a migration file
|
||||
// Includes: Sequence(version), SQL queries to run on up & down
|
||||
type Migration struct {
|
||||
Sequence int
|
||||
Name string
|
||||
UpSQL string
|
||||
DownSQL string
|
||||
}
|
||||
|
||||
// Migrator holds pointer to database struct slice of Migrations and logger
|
||||
type Migrator struct {
|
||||
Conn *database.Database
|
||||
Migrations []Migration
|
||||
Log Logger
|
||||
}
|
||||
|
||||
// Logger interface implementation
|
||||
// Allows you to BYO Logging/Printing
|
||||
|
||||
type Logger interface {
|
||||
Printf(format string, v ...interface{})
|
||||
Println(v ...interface{})
|
||||
Errorf(format string, v ...interface{})
|
||||
}
|
||||
25
database/migration/migration_logger.go
Normal file
25
database/migration/migration_logger.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package migrations
|
||||
|
||||
import (
|
||||
log "github.com/thrasher-corp/gocryptotrader/logger"
|
||||
)
|
||||
|
||||
type MLogger struct{}
|
||||
|
||||
// Printf implantation of migration Logger interface
|
||||
// Passes off to log.Infof
|
||||
func (t MLogger) Printf(format string, v ...interface{}) {
|
||||
log.Infof(log.DatabaseMgr, format, v...)
|
||||
}
|
||||
|
||||
// Println implantation of migration Logger interface
|
||||
// Passes off to log.Infoln
|
||||
func (t MLogger) Println(v ...interface{}) {
|
||||
log.Infoln(log.DatabaseMgr, v...)
|
||||
}
|
||||
|
||||
// Errorf implantation of migration Logger interface
|
||||
// Passes off to log.Errorf
|
||||
func (t MLogger) Errorf(format string, v ...interface{}) {
|
||||
log.Errorf(log.DatabaseMgr, format, v...)
|
||||
}
|
||||
11
database/migration/migrations/1565657999_create_audit_event_table.sql
Executable file
11
database/migration/migrations/1565657999_create_audit_event_table.sql
Executable file
@@ -0,0 +1,11 @@
|
||||
-- up
|
||||
CREATE TABLE IF NOT EXISTS audit_event
|
||||
(
|
||||
id bigserial PRIMARY KEY NOT NULL,
|
||||
Type varchar(255) NOT NULL,
|
||||
Identifier varchar(255) NOT NULL,
|
||||
Message text NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT now()
|
||||
);
|
||||
-- down
|
||||
DROP TABLE audit_event;
|
||||
8
database/models/audit.go
Normal file
8
database/models/audit.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package models
|
||||
|
||||
// AuditEvent is a model of how the data is represented in a database
|
||||
type AuditEvent struct {
|
||||
Type string
|
||||
Identifier string
|
||||
Message string
|
||||
}
|
||||
62
database/repository/audit/audit.go
Normal file
62
database/repository/audit/audit.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/thrasher-corp/gocryptotrader/database"
|
||||
"github.com/thrasher-corp/gocryptotrader/database/models"
|
||||
log "github.com/thrasher-corp/gocryptotrader/logger"
|
||||
)
|
||||
|
||||
// Repository that is required for each driver type to implement
|
||||
type Repository interface {
|
||||
AddEventTx(event []*models.AuditEvent)
|
||||
}
|
||||
|
||||
var (
|
||||
// Audit repository initialise copy of Audit Repository
|
||||
Audit Repository
|
||||
)
|
||||
|
||||
type eventPool struct {
|
||||
events []*models.AuditEvent
|
||||
eventMu sync.Mutex
|
||||
}
|
||||
|
||||
var ep eventPool
|
||||
|
||||
// Event allows you to call audit.Event() as long as the audit repository package without the need to include each driver
|
||||
func Event(msgType, identifier, message string) {
|
||||
if database.Conn.SQL == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if Audit == nil {
|
||||
return
|
||||
}
|
||||
|
||||
tempEvent := models.AuditEvent{
|
||||
Type: msgType,
|
||||
Identifier: identifier,
|
||||
Message: message}
|
||||
|
||||
ep.poolEvents(&tempEvent)
|
||||
}
|
||||
|
||||
func (e *eventPool) poolEvents(event *models.AuditEvent) {
|
||||
e.eventMu.Lock()
|
||||
defer e.eventMu.Unlock()
|
||||
|
||||
e.events = append(e.events, event)
|
||||
|
||||
database.Conn.Mu.RLock()
|
||||
defer database.Conn.Mu.RUnlock()
|
||||
|
||||
if !database.Conn.Connected {
|
||||
log.Warnln(log.DatabaseMgr, "connection to database interrupted pooling database writes")
|
||||
return
|
||||
}
|
||||
|
||||
Audit.AddEventTx(e.events)
|
||||
e.events = nil
|
||||
}
|
||||
52
database/repository/audit/postgres/audit.go
Normal file
52
database/repository/audit/postgres/audit.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"github.com/thrasher-corp/gocryptotrader/database"
|
||||
"github.com/thrasher-corp/gocryptotrader/database/models"
|
||||
"github.com/thrasher-corp/gocryptotrader/database/repository/audit"
|
||||
log "github.com/thrasher-corp/gocryptotrader/logger"
|
||||
)
|
||||
|
||||
type auditRepo struct{}
|
||||
|
||||
// Audit returns a new instance of auditRepo
|
||||
func Audit() audit.Repository {
|
||||
return &auditRepo{}
|
||||
}
|
||||
|
||||
// AddEventTx writes multiple events to database
|
||||
// writes are done using a transaction with a rollback on error
|
||||
func (pg *auditRepo) AddEventTx(event []*models.AuditEvent) {
|
||||
if pg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
tx, err := database.Conn.SQL.Begin()
|
||||
if err != nil {
|
||||
log.Errorf(log.Global, "Failed to create transaction: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
query := `INSERT INTO audit_event (type, identifier, message) VALUES($1, $2, $3)`
|
||||
|
||||
for x := range event {
|
||||
_, err = tx.Exec(query, &event[x].Type, &event[x].Identifier, &event[x].Message)
|
||||
|
||||
if err != nil {
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
log.Errorf(log.Global, "Tx Rollback has failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
log.Errorf(log.Global, "Tx Rollback has failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
53
database/repository/audit/sqlite/audit.go
Normal file
53
database/repository/audit/sqlite/audit.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"github.com/thrasher-corp/gocryptotrader/database"
|
||||
"github.com/thrasher-corp/gocryptotrader/database/models"
|
||||
"github.com/thrasher-corp/gocryptotrader/database/repository/audit"
|
||||
log "github.com/thrasher-corp/gocryptotrader/logger"
|
||||
)
|
||||
|
||||
type auditRepo struct{}
|
||||
|
||||
// Audit returns a new instance of auditRepo
|
||||
func Audit() audit.Repository {
|
||||
return &auditRepo{}
|
||||
}
|
||||
|
||||
// AddEventTx writes multiple event to database
|
||||
// writes are done using a transaction with a rollback on error
|
||||
func (pg *auditRepo) AddEventTx(event []*models.AuditEvent) {
|
||||
if pg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
tx, err := database.Conn.SQL.Begin()
|
||||
if err != nil {
|
||||
log.Errorf(log.Global, "Failed to create transaction: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
query := `INSERT INTO audit_event (type, identifier, message) VALUES($1, $2, $3)`
|
||||
|
||||
for x := range event {
|
||||
_, err = tx.Exec(query, &event[x].Type, &event[x].Identifier, &event[x].Message)
|
||||
|
||||
if err != nil {
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
log.Errorf(log.Global, "Tx Rollback has failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
err = tx.Rollback()
|
||||
if err != nil {
|
||||
log.Errorf(log.Global, "Tx Rollback has failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
126
database/tests/audit_test.go
Normal file
126
database/tests/audit_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/thrasher-corp/gocryptotrader/database"
|
||||
"github.com/thrasher-corp/gocryptotrader/database/drivers"
|
||||
mg "github.com/thrasher-corp/gocryptotrader/database/migration"
|
||||
"github.com/thrasher-corp/gocryptotrader/database/repository/audit"
|
||||
auditPSQL "github.com/thrasher-corp/gocryptotrader/database/repository/audit/postgres"
|
||||
auditSQlite "github.com/thrasher-corp/gocryptotrader/database/repository/audit/sqlite"
|
||||
)
|
||||
|
||||
func TestAudit(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
config database.Config
|
||||
audit audit.Repository
|
||||
runner func(t *testing.T)
|
||||
closer func(t *testing.T, dbConn *database.Database) error
|
||||
output interface{}
|
||||
}{
|
||||
{
|
||||
"SQLite",
|
||||
database.Config{
|
||||
Driver: "sqlite",
|
||||
ConnectionDetails: drivers.ConnectionDetails{Database: path.Join(tempDir, "./testdb.db")},
|
||||
},
|
||||
auditSQlite.Audit(),
|
||||
writeAudit,
|
||||
closeDatabase,
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Postgres",
|
||||
postgresTestDatabase,
|
||||
auditPSQL.Audit(),
|
||||
writeAudit,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tests := range testCases {
|
||||
test := tests
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
|
||||
mg.MigrationDir = filepath.Join("../migration", "migrations")
|
||||
|
||||
if !checkValidConfig(t, &test.config.ConnectionDetails) {
|
||||
t.Skip("database not configured skipping test")
|
||||
}
|
||||
|
||||
dbConn, err := connectToDatabase(t, &test.config)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mLogger := mg.MLogger{}
|
||||
migrations := mg.Migrator{
|
||||
Log: mLogger,
|
||||
}
|
||||
|
||||
migrations.Conn = dbConn
|
||||
|
||||
err = migrations.LoadMigrations()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = migrations.RunMigration()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if test.audit != nil {
|
||||
audit.Audit = test.audit
|
||||
}
|
||||
|
||||
if test.runner != nil {
|
||||
test.runner(t)
|
||||
}
|
||||
|
||||
switch v := test.output.(type) {
|
||||
|
||||
case error:
|
||||
if v.Error() != test.output.(error).Error() {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return
|
||||
default:
|
||||
break
|
||||
}
|
||||
|
||||
if test.closer != nil {
|
||||
err = test.closer(t, dbConn)
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func writeAudit(t *testing.T) {
|
||||
t.Helper()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for x := 0; x < 200; x++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(x int) {
|
||||
defer wg.Done()
|
||||
test := fmt.Sprintf("test-%v", x)
|
||||
audit.Event(test, test, test)
|
||||
}(x)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
148
database/tests/db_test.go
Normal file
148
database/tests/db_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/thrasher-corp/gocryptotrader/database"
|
||||
"github.com/thrasher-corp/gocryptotrader/database/drivers"
|
||||
dbpsql "github.com/thrasher-corp/gocryptotrader/database/drivers/postgres"
|
||||
dbsqlite "github.com/thrasher-corp/gocryptotrader/database/drivers/sqlite"
|
||||
)
|
||||
|
||||
var (
|
||||
tempDir string
|
||||
|
||||
postgresTestDatabase = database.Config{
|
||||
Enabled: true,
|
||||
Driver: "postgres",
|
||||
ConnectionDetails: drivers.ConnectionDetails{
|
||||
//Host: "",
|
||||
//Port: 5432,
|
||||
//Username: "",
|
||||
//Password: "",
|
||||
//Database: "",
|
||||
//SSLMode: "",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
var err error
|
||||
tempDir, err = ioutil.TempDir("", "gct-temp")
|
||||
if err != nil {
|
||||
fmt.Printf("failed to create temp file: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
t := m.Run()
|
||||
|
||||
err = os.RemoveAll(tempDir)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to remove temp db file: %v", err)
|
||||
}
|
||||
|
||||
os.Exit(t)
|
||||
}
|
||||
|
||||
func TestDatabaseConnect(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
config database.Config
|
||||
closer func(t *testing.T, dbConn *database.Database) error
|
||||
output interface{}
|
||||
}{
|
||||
{
|
||||
"SQLite",
|
||||
database.Config{
|
||||
Driver: "sqlite",
|
||||
ConnectionDetails: drivers.ConnectionDetails{Database: path.Join(tempDir, "./testdb.db")},
|
||||
},
|
||||
closeDatabase,
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"SQliteNoDatabase",
|
||||
database.Config{
|
||||
Driver: "sqlite",
|
||||
ConnectionDetails: drivers.ConnectionDetails{
|
||||
Host: "localhost",
|
||||
},
|
||||
},
|
||||
nil,
|
||||
database.ErrNoDatabaseProvided,
|
||||
},
|
||||
{
|
||||
name: "Postgres",
|
||||
config: postgresTestDatabase,
|
||||
output: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tests := range testCases {
|
||||
test := tests
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
if !checkValidConfig(t, &test.config.ConnectionDetails) {
|
||||
t.Skip("database not configured skipping test")
|
||||
}
|
||||
|
||||
dbConn, err := connectToDatabase(t, &test.config)
|
||||
if err != nil {
|
||||
switch v := test.output.(type) {
|
||||
case error:
|
||||
if v.Error() != err.Error() {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return
|
||||
default:
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if test.closer != nil {
|
||||
err = test.closer(t, dbConn)
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func connectToDatabase(t *testing.T, conn *database.Config) (dbConn *database.Database, err error) {
|
||||
t.Helper()
|
||||
database.Conn.Config = conn
|
||||
|
||||
if conn.Driver == "postgres" {
|
||||
dbConn, err = dbpsql.Connect()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else if conn.Driver == "sqlite" {
|
||||
dbConn, err = dbsqlite.Connect()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
database.Conn.Connected = true
|
||||
return
|
||||
}
|
||||
|
||||
func closeDatabase(t *testing.T, conn *database.Database) (err error) {
|
||||
t.Helper()
|
||||
|
||||
if conn != nil {
|
||||
return conn.SQL.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkValidConfig(t *testing.T, config *drivers.ConnectionDetails) bool {
|
||||
t.Helper()
|
||||
|
||||
return !reflect.DeepEqual(drivers.ConnectionDetails{}, *config)
|
||||
}
|
||||
Reference in New Issue
Block a user