You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							305 lines
						
					
					
						
							6.6 KiB
						
					
					
				
			
		
		
	
	
							305 lines
						
					
					
						
							6.6 KiB
						
					
					
				| package testfixtures
 | |
| 
 | |
| import (
 | |
| 	"database/sql"
 | |
| 	"fmt"
 | |
| 	"io/ioutil"
 | |
| 	"path"
 | |
| 	"path/filepath"
 | |
| 	"regexp"
 | |
| 	"strings"
 | |
| 
 | |
| 	"gopkg.in/yaml.v2"
 | |
| )
 | |
| 
 | |
| // Context holds the fixtures to be loaded in the database.
 | |
| type Context struct {
 | |
| 	db            *sql.DB
 | |
| 	helper        Helper
 | |
| 	fixturesFiles []*fixtureFile
 | |
| }
 | |
| 
 | |
| type fixtureFile struct {
 | |
| 	path       string
 | |
| 	fileName   string
 | |
| 	content    []byte
 | |
| 	insertSQLs []insertSQL
 | |
| }
 | |
| 
 | |
// insertSQL pairs an INSERT statement with the arguments to execute it with.
type insertSQL struct {
	// sql is the statement text, placeholders included.
	sql string
	// params are the values passed alongside sql when it is executed.
	params []interface{}
}
 | |
| 
 | |
// dbnameRegexp matches any name that contains "test", ignoring case.
// DetectTestDatabase checks the current database name against it.
var dbnameRegexp = regexp.MustCompile("(?i)test")
 | |
| 
 | |
| // NewFolder creates a context for all fixtures in a given folder into the database:
 | |
| //     NewFolder(db, &PostgreSQL{}, "my/fixtures/folder")
 | |
| func NewFolder(db *sql.DB, helper Helper, folderName string) (*Context, error) {
 | |
| 	fixtures, err := fixturesFromFolder(folderName)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	c, err := newContext(db, helper, fixtures)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return c, nil
 | |
| }
 | |
| 
 | |
| // NewFiles creates a context for all specified fixtures files into database:
 | |
| //     NewFiles(db, &PostgreSQL{},
 | |
| //         "fixtures/customers.yml",
 | |
| //         "fixtures/orders.yml"
 | |
| //         // add as many files you want
 | |
| //     )
 | |
| func NewFiles(db *sql.DB, helper Helper, fileNames ...string) (*Context, error) {
 | |
| 	fixtures, err := fixturesFromFiles(fileNames...)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	c, err := newContext(db, helper, fixtures)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return c, nil
 | |
| }
 | |
| 
 | |
| func newContext(db *sql.DB, helper Helper, fixtures []*fixtureFile) (*Context, error) {
 | |
| 	c := &Context{
 | |
| 		db:            db,
 | |
| 		helper:        helper,
 | |
| 		fixturesFiles: fixtures,
 | |
| 	}
 | |
| 
 | |
| 	if err := c.helper.init(c.db); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	if err := c.buildInsertSQLs(); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return c, nil
 | |
| }
 | |
| 
 | |
| // DetectTestDatabase returns nil if databaseName matches regexp
 | |
| //     if err := fixtures.DetectTestDatabase(); err != nil {
 | |
| //         log.Fatal(err)
 | |
| //     }
 | |
| func (c *Context) DetectTestDatabase() error {
 | |
| 	dbName, err := c.helper.databaseName(c.db)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	if !dbnameRegexp.MatchString(dbName) {
 | |
| 		return ErrNotTestDatabase
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // Load wipes and after load all fixtures in the database.
 | |
| //     if err := fixtures.Load(); err != nil {
 | |
| //         log.Fatal(err)
 | |
| //     }
 | |
| func (c *Context) Load() error {
 | |
| 	if !skipDatabaseNameCheck {
 | |
| 		if err := c.DetectTestDatabase(); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	err := c.helper.disableReferentialIntegrity(c.db, func(tx *sql.Tx) error {
 | |
| 		for _, file := range c.fixturesFiles {
 | |
| 			modified, err := c.helper.isTableModified(tx, file.fileNameWithoutExtension())
 | |
| 			if err != nil {
 | |
| 				return err
 | |
| 			}
 | |
| 			if !modified {
 | |
| 				continue
 | |
| 			}
 | |
| 			if err := file.delete(tx, c.helper); err != nil {
 | |
| 				return err
 | |
| 			}
 | |
| 
 | |
| 			err = c.helper.whileInsertOnTable(tx, file.fileNameWithoutExtension(), func() error {
 | |
| 				for j, i := range file.insertSQLs {
 | |
| 					if _, err := tx.Exec(i.sql, i.params...); err != nil {
 | |
| 						return &InsertError{
 | |
| 							Err:    err,
 | |
| 							File:   file.fileName,
 | |
| 							Index:  j,
 | |
| 							SQL:    i.sql,
 | |
| 							Params: i.params,
 | |
| 						}
 | |
| 					}
 | |
| 				}
 | |
| 				return nil
 | |
| 			})
 | |
| 			if err != nil {
 | |
| 				return err
 | |
| 			}
 | |
| 		}
 | |
| 		return nil
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	return c.helper.afterLoad(c.db)
 | |
| }
 | |
| 
 | |
| func (c *Context) buildInsertSQLs() error {
 | |
| 	for _, f := range c.fixturesFiles {
 | |
| 		var records interface{}
 | |
| 		if err := yaml.Unmarshal(f.content, &records); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		switch records := records.(type) {
 | |
| 		case []interface{}:
 | |
| 			for _, record := range records {
 | |
| 				recordMap, ok := record.(map[interface{}]interface{})
 | |
| 				if !ok {
 | |
| 					return ErrWrongCastNotAMap
 | |
| 				}
 | |
| 
 | |
| 				sql, values, err := f.buildInsertSQL(c.helper, recordMap)
 | |
| 				if err != nil {
 | |
| 					return err
 | |
| 				}
 | |
| 
 | |
| 				f.insertSQLs = append(f.insertSQLs, insertSQL{sql, values})
 | |
| 			}
 | |
| 		case map[interface{}]interface{}:
 | |
| 			for _, record := range records {
 | |
| 				recordMap, ok := record.(map[interface{}]interface{})
 | |
| 				if !ok {
 | |
| 					return ErrWrongCastNotAMap
 | |
| 				}
 | |
| 
 | |
| 				sql, values, err := f.buildInsertSQL(c.helper, recordMap)
 | |
| 				if err != nil {
 | |
| 					return err
 | |
| 				}
 | |
| 
 | |
| 				f.insertSQLs = append(f.insertSQLs, insertSQL{sql, values})
 | |
| 			}
 | |
| 		default:
 | |
| 			return ErrFileIsNotSliceOrMap
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (f *fixtureFile) fileNameWithoutExtension() string {
 | |
| 	return strings.Replace(f.fileName, filepath.Ext(f.fileName), "", 1)
 | |
| }
 | |
| 
 | |
| func (f *fixtureFile) delete(tx *sql.Tx, h Helper) error {
 | |
| 	_, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", h.quoteKeyword(f.fileNameWithoutExtension())))
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| func (f *fixtureFile) buildInsertSQL(h Helper, record map[interface{}]interface{}) (sqlStr string, values []interface{}, err error) {
 | |
| 	var (
 | |
| 		sqlColumns []string
 | |
| 		sqlValues  []string
 | |
| 		i          = 1
 | |
| 	)
 | |
| 	for key, value := range record {
 | |
| 		keyStr, ok := key.(string)
 | |
| 		if !ok {
 | |
| 			err = ErrKeyIsNotString
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		sqlColumns = append(sqlColumns, h.quoteKeyword(keyStr))
 | |
| 
 | |
| 		// if string, try convert to SQL or time
 | |
| 		// if map or array, convert to json
 | |
| 		switch v := value.(type) {
 | |
| 		case string:
 | |
| 			if strings.HasPrefix(v, "RAW=") {
 | |
| 				sqlValues = append(sqlValues, strings.TrimPrefix(v, "RAW="))
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			if t, err := tryStrToDate(v); err == nil {
 | |
| 				value = t
 | |
| 			}
 | |
| 		case []interface{}, map[interface{}]interface{}:
 | |
| 			value = recursiveToJSON(v)
 | |
| 		}
 | |
| 
 | |
| 		switch h.paramType() {
 | |
| 		case paramTypeDollar:
 | |
| 			sqlValues = append(sqlValues, fmt.Sprintf("$%d", i))
 | |
| 		case paramTypeQuestion:
 | |
| 			sqlValues = append(sqlValues, "?")
 | |
| 		case paramTypeColon:
 | |
| 			sqlValues = append(sqlValues, fmt.Sprintf(":%d", i))
 | |
| 		}
 | |
| 
 | |
| 		values = append(values, value)
 | |
| 		i++
 | |
| 	}
 | |
| 
 | |
| 	sqlStr = fmt.Sprintf(
 | |
| 		"INSERT INTO %s (%s) VALUES (%s)",
 | |
| 		h.quoteKeyword(f.fileNameWithoutExtension()),
 | |
| 		strings.Join(sqlColumns, ", "),
 | |
| 		strings.Join(sqlValues, ", "),
 | |
| 	)
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func fixturesFromFolder(folderName string) ([]*fixtureFile, error) {
 | |
| 	var files []*fixtureFile
 | |
| 	fileinfos, err := ioutil.ReadDir(folderName)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	for _, fileinfo := range fileinfos {
 | |
| 		if !fileinfo.IsDir() && filepath.Ext(fileinfo.Name()) == ".yml" {
 | |
| 			fixture := &fixtureFile{
 | |
| 				path:     path.Join(folderName, fileinfo.Name()),
 | |
| 				fileName: fileinfo.Name(),
 | |
| 			}
 | |
| 			fixture.content, err = ioutil.ReadFile(fixture.path)
 | |
| 			if err != nil {
 | |
| 				return nil, err
 | |
| 			}
 | |
| 			files = append(files, fixture)
 | |
| 		}
 | |
| 	}
 | |
| 	return files, nil
 | |
| }
 | |
| 
 | |
| func fixturesFromFiles(fileNames ...string) ([]*fixtureFile, error) {
 | |
| 	var (
 | |
| 		fixtureFiles []*fixtureFile
 | |
| 		err          error
 | |
| 	)
 | |
| 
 | |
| 	for _, f := range fileNames {
 | |
| 		fixture := &fixtureFile{
 | |
| 			path:     f,
 | |
| 			fileName: filepath.Base(f),
 | |
| 		}
 | |
| 		fixture.content, err = ioutil.ReadFile(fixture.path)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		fixtureFiles = append(fixtureFiles, fixture)
 | |
| 	}
 | |
| 
 | |
| 	return fixtureFiles, nil
 | |
| }
 | |
| 
 |