package testfixtures // import ""

import (


// Loader is the responsible to loading fixtures.
type Loader struct {
	db            *sql.DB
	helper        helper
	fixturesFiles []*fixtureFile

	skipTestDatabaseCheck bool
	location              *time.Location

	template           bool
	templateFuncs      template.FuncMap
	templateLeftDelim  string
	templateRightDelim string
	templateOptions    []string
	templateData       interface{}

type fixtureFile struct {
	path       string
	fileName   string
	content    []byte
	insertSQLs []insertSQL

// insertSQL is one INSERT statement together with its positional parameters.
type insertSQL struct {
	sql    string
	params []interface{}
}

var (
	// testDatabaseRegexp matches database names containing "test" in any
	// letter case.
	testDatabaseRegexp = regexp.MustCompile("(?i)test")

	errDatabaseIsRequired = fmt.Errorf("testfixtures: database is required")
	errDialectIsRequired  = fmt.Errorf("testfixtures: dialect is required")
)

// New instantiates a new Loader instance. The "Database" and "Driver"
// options are required.
func New(options ...func(*Loader) error) (*Loader, error) {
	l := &Loader{
		templateLeftDelim:  "{{",
		templateRightDelim: "}}",
		templateOptions:    []string{"missingkey=zero"},

	for _, option := range options {
		if err := option(l); err != nil {
			return nil, err

	if l.db == nil {
		return nil, errDatabaseIsRequired
	if l.helper == nil {
		return nil, errDialectIsRequired

	if err := l.helper.init(l.db); err != nil {
		return nil, err
	if err := l.buildInsertSQLs(); err != nil {
		return nil, err

	return l, nil

// Database sets an existing sql.DB instant to Loader.
func Database(db *sql.DB) func(*Loader) error {
	return func(l *Loader) error {
		l.db = db
		return nil

// Dialect informs Loader about which database dialect you're using.
// Possible options are "postgresql", "timescaledb", "mysql", "mariadb",
// "sqlite" and "sqlserver".
func Dialect(dialect string) func(*Loader) error {
	return func(l *Loader) error {
		h, err := helperForDialect(dialect)
		if err != nil {
			return err
		l.helper = h
		return nil

func helperForDialect(dialect string) (helper, error) {
	switch dialect {
	case "postgres", "postgresql", "timescaledb":
		return &postgreSQL{}, nil
	case "mysql", "mariadb":
		return &mySQL{}, nil
	case "sqlite", "sqlite3":
		return &sqlite{}, nil
	case "mssql", "sqlserver":
		return &sqlserver{}, nil
		return nil, fmt.Errorf(`testfixtures: unrecognized dialect "%s"`, dialect)

// UseAlterConstraint If true, the contraint disabling will do
// using ALTER CONTRAINT sintax, only allowed in PG >= 9.4.
// If false, the constraint disabling will use DISABLE TRIGGER ALL,
// which requires SUPERUSER privileges.
// Only valid for PostgreSQL. Returns an error otherwise.
func UseAlterConstraint() func(*Loader) error {
	return func(l *Loader) error {
		pgHelper, ok := l.helper.(*postgreSQL)
		if !ok {
			return fmt.Errorf("testfixtures: UseAlterConstraint is only valid for PostgreSQL databases")
		pgHelper.useAlterConstraint = true
		return nil

// SkipResetSequences prevents Loader from reseting sequences after loading
// fixtures.
// Only valid for PostgreSQL. Returns an error otherwise.
func SkipResetSequences() func(*Loader) error {
	return func(l *Loader) error {
		pgHelper, ok := l.helper.(*postgreSQL)
		if !ok {
			return fmt.Errorf("testfixtures: SkipResetSequences is only valid for PostgreSQL databases")
		pgHelper.skipResetSequences = true
		return nil

// ResetSequencesTo sets the value the sequences will be reset to.
// Defaults to 10000.
// Only valid for PostgreSQL. Returns an error otherwise.
func ResetSequencesTo(value int64) func(*Loader) error {
	return func(l *Loader) error {
		pgHelper, ok := l.helper.(*postgreSQL)
		if !ok {
			return fmt.Errorf("testfixtures: ResetSequencesTo is only valid for PostgreSQL databases")
		pgHelper.resetSequencesTo = value
		return nil

// DangerousSkipTestDatabaseCheck will make Loader not check if the database
// name contains "test". Use with caution!
func DangerousSkipTestDatabaseCheck() func(*Loader) error {
	return func(l *Loader) error {
		l.skipTestDatabaseCheck = true
		return nil

// Directory informs Loader to load YAML files from a given directory.
func Directory(dir string) func(*Loader) error {
	return func(l *Loader) error {
		fixtures, err := l.fixturesFromDir(dir)
		if err != nil {
			return err
		l.fixturesFiles = append(l.fixturesFiles, fixtures...)
		return nil

// Files informs Loader to load a given set of YAML files.
func Files(files ...string) func(*Loader) error {
	return func(l *Loader) error {
		fixtures, err := l.fixturesFromFiles(files...)
		if err != nil {
			return err
		l.fixturesFiles = append(l.fixturesFiles, fixtures...)
		return nil

// Paths inform Loader to load a given set of YAML files and directories.
func Paths(paths ...string) func(*Loader) error {
	return func(l *Loader) error {
		fixtures, err := l.fixturesFromPaths(paths...)
		if err != nil {
			return err
		l.fixturesFiles = append(l.fixturesFiles, fixtures...)
		return nil

// Location makes Loader use the given location by default when parsing
// dates. If not given, by default it uses the value of time.Local.
func Location(location *time.Location) func(*Loader) error {
	return func(l *Loader) error {
		l.location = location
		return nil

// Template makes loader process each YAML file as an template using the
// text/template package.
// For more information on how templates work in Go please read:
// If not given the YAML files are parsed as is.
func Template() func(*Loader) error {
	return func(l *Loader) error {
		l.template = true
		return nil

// TemplateFuncs allow choosing which functions will be available
// when processing templates.
// For more information see:
func TemplateFuncs(funcs template.FuncMap) func(*Loader) error {
	return func(l *Loader) error {
		if !l.template {
			return fmt.Errorf(`testfixtures: the Template() options is required in order to use the TemplateFuns() option`)

		l.templateFuncs = funcs
		return nil

// TemplateDelims allow choosing which delimiters will be used for templating.
// This defaults to "{{" and "}}".
// For more information see
func TemplateDelims(left, right string) func(*Loader) error {
	return func(l *Loader) error {
		if !l.template {
			return fmt.Errorf(`testfixtures: the Template() options is required in order to use the TemplateDelims() option`)

		l.templateLeftDelim = left
		l.templateRightDelim = right
		return nil

// TemplateOptions allows you to specific which text/template options will
// be enabled when processing templates.
// This defaults to "missingkey=zero". Check the available options here:
func TemplateOptions(options ...string) func(*Loader) error {
	return func(l *Loader) error {
		if !l.template {
			return fmt.Errorf(`testfixtures: the Template() options is required in order to use the TemplateOptions() option`)

		l.templateOptions = options
		return nil

// TemplateData allows you to specify which data will be available
// when processing templates. Data is accesible by prefixing it with a "."
// like {{.MyKey}}.
func TemplateData(data interface{}) func(*Loader) error {
	return func(l *Loader) error {
		if !l.template {
			return fmt.Errorf(`testfixtures: the Template() options is required in order to use the TemplateData() option`)

		l.templateData = data
		return nil

// EnsureTestDatabase returns an error if the database name does not contains
// "test".
func (l *Loader) EnsureTestDatabase() error {
	dbName, err := l.helper.databaseName(l.db)
	if err != nil {
		return err
	if !testDatabaseRegexp.MatchString(dbName) {
		return fmt.Errorf(`testfixtures: database "%s" does not appear to be a test database`, dbName)
	return nil

// Load wipes and after load all fixtures in the database.
//     if err := fixtures.Load(); err != nil {
//             ...
//     }
func (l *Loader) Load() error {
	if !l.skipTestDatabaseCheck {
		if err := l.EnsureTestDatabase(); err != nil {
			return err

	err := l.helper.disableReferentialIntegrity(l.db, func(tx *sql.Tx) error {
		for _, file := range l.fixturesFiles {
			modified, err := l.helper.isTableModified(tx, file.fileNameWithoutExtension())
			if err != nil {
				return err
			if !modified {
			if err := file.delete(tx, l.helper); err != nil {
				return err

			err = l.helper.whileInsertOnTable(tx, file.fileNameWithoutExtension(), func() error {
				for j, i := range file.insertSQLs {
					if _, err := tx.Exec(i.sql, i.params...); err != nil {
						return &InsertError{
							Err:    err,
							File:   file.fileName,
							Index:  j,
							SQL:    i.sql,
							Params: i.params,
				return nil
			if err != nil {
				return err
		return nil
	if err != nil {
		return err
	return l.helper.afterLoad(l.db)

// InsertError is returned when the database reports an error while
// inserting a record. It identifies the fixture file, the zero-based index
// of the failing statement within that file, and the SQL and parameters
// used. The underlying database error is exposed through Unwrap, so
// errors.Is and errors.As can inspect it.
type InsertError struct {
	Err    error
	File   string
	Index  int
	SQL    string
	Params []interface{}
}

// Error implements the error interface.
func (e *InsertError) Error() string {
	return fmt.Sprintf(
		"testfixtures: error inserting record: %v, on file: %s, index: %d, sql: %s, params: %v",
		e.Err,
		e.File,
		e.Index,
		e.SQL,
		e.Params,
	)
}

// Unwrap returns the underlying database error.
func (e *InsertError) Unwrap() error {
	return e.Err
}

func (l *Loader) buildInsertSQLs() error {
	for _, f := range l.fixturesFiles {
		var records interface{}
		if err := yaml.Unmarshal(f.content, &records); err != nil {
			return fmt.Errorf("testfixtures: could not unmarshal YAML: %w", err)

		switch records := records.(type) {
		case []interface{}:
			f.insertSQLs = make([]insertSQL, 0, len(records))

			for _, record := range records {
				recordMap, ok := record.(map[interface{}]interface{})
				if !ok {
					return fmt.Errorf("testfixtures: could not cast record: not a map[interface{}]interface{}")

				sql, values, err := l.buildInsertSQL(f, recordMap)
				if err != nil {
					return err

				f.insertSQLs = append(f.insertSQLs, insertSQL{sql, values})
		case map[interface{}]interface{}:
			f.insertSQLs = make([]insertSQL, 0, len(records))

			for _, record := range records {
				recordMap, ok := record.(map[interface{}]interface{})
				if !ok {
					return fmt.Errorf("testfixtures: could not cast record: not a map[interface{}]interface{}")

				sql, values, err := l.buildInsertSQL(f, recordMap)
				if err != nil {
					return err

				f.insertSQLs = append(f.insertSQLs, insertSQL{sql, values})
			return fmt.Errorf("testfixtures: fixture is not a slice or map")

	return nil

func (f *fixtureFile) fileNameWithoutExtension() string {
	return strings.Replace(f.fileName, filepath.Ext(f.fileName), "", 1)

func (f *fixtureFile) delete(tx *sql.Tx, h helper) error {
	if _, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", h.quoteKeyword(f.fileNameWithoutExtension()))); err != nil {
		return fmt.Errorf(`testfixtures: could not clean table "%s": %w`, f.fileNameWithoutExtension(), err)
	return nil

func (l *Loader) buildInsertSQL(f *fixtureFile, record map[interface{}]interface{}) (sqlStr string, values []interface{}, err error) {
	var (
		sqlColumns = make([]string, 0, len(record))
		sqlValues  = make([]string, 0, len(record))
		i          = 1
	for key, value := range record {
		keyStr, ok := key.(string)
		if !ok {
			err = fmt.Errorf("testfixtures: record map key is not a string")

		sqlColumns = append(sqlColumns, l.helper.quoteKeyword(keyStr))

		// if string, try convert to SQL or time
		// if map or array, convert to json
		switch v := value.(type) {
		case string:
			if strings.HasPrefix(v, "RAW=") {
				sqlValues = append(sqlValues, strings.TrimPrefix(v, "RAW="))

			if t, err := l.tryStrToDate(v); err == nil {
				value = t
		case []interface{}, map[interface{}]interface{}:
			value = recursiveToJSON(v)

		switch l.helper.paramType() {
		case paramTypeDollar:
			sqlValues = append(sqlValues, fmt.Sprintf("$%d", i))
		case paramTypeQuestion:
			sqlValues = append(sqlValues, "?")
		case paramTypeAtSign:
			sqlValues = append(sqlValues, fmt.Sprintf("@p%d", i))

		values = append(values, value)

	sqlStr = fmt.Sprintf(
		"INSERT INTO %s (%s) VALUES (%s)",
		strings.Join(sqlColumns, ", "),
		strings.Join(sqlValues, ", "),

func (l *Loader) fixturesFromDir(dir string) ([]*fixtureFile, error) {
	fileinfos, err := ioutil.ReadDir(dir)
	if err != nil {
		return nil, fmt.Errorf(`testfixtures: could not stat directory "%s": %w`, dir, err)

	files := make([]*fixtureFile, 0, len(fileinfos))

	for _, fileinfo := range fileinfos {
		fileExt := filepath.Ext(fileinfo.Name())
		if !fileinfo.IsDir() && (fileExt == ".yml" || fileExt == ".yaml") {
			fixture := &fixtureFile{
				path:     path.Join(dir, fileinfo.Name()),
				fileName: fileinfo.Name(),
			fixture.content, err = ioutil.ReadFile(fixture.path)
			if err != nil {
				return nil, fmt.Errorf(`testfixtures: could not read file "%s": %w`, fixture.path, err)
			if err := l.processFileTemplate(fixture); err != nil {
				return nil, err
			files = append(files, fixture)
	return files, nil

func (l *Loader) fixturesFromFiles(fileNames ...string) ([]*fixtureFile, error) {
	var (
		fixtureFiles = make([]*fixtureFile, 0, len(fileNames))
		err          error

	for _, f := range fileNames {
		fixture := &fixtureFile{
			path:     f,
			fileName: filepath.Base(f),
		fixture.content, err = ioutil.ReadFile(fixture.path)
		if err != nil {
			return nil, fmt.Errorf(`testfixtures: could not read file "%s": %w`, fixture.path, err)
		if err := l.processFileTemplate(fixture); err != nil {
			return nil, err
		fixtureFiles = append(fixtureFiles, fixture)

	return fixtureFiles, nil

func (l *Loader) fixturesFromPaths(paths ...string) ([]*fixtureFile, error) {
	fixtureExtractor := func(p string, isDir bool) ([]*fixtureFile, error) {
		if isDir {
			return l.fixturesFromDir(p)

		return l.fixturesFromFiles(p)

	var fixtureFiles []*fixtureFile

	for _, p := range paths {
		f, err := os.Stat(p)
		if err != nil {
			return nil, fmt.Errorf(`testfixtures: could not stat path "%s": %w`, p, err)

		fixtures, err := fixtureExtractor(p, f.IsDir())
		if err != nil {
			return nil, err

		fixtureFiles = append(fixtureFiles, fixtures...)

	return fixtureFiles, nil

func (l *Loader) processFileTemplate(f *fixtureFile) error {
	if !l.template {
		return nil

	t := template.New("").
		Delims(l.templateLeftDelim, l.templateRightDelim).
	t, err := t.Parse(string(f.content))
	if err != nil {
		return fmt.Errorf(`textfixtures: error on parsing template in %s: %w`, f.fileName, err)

	var buffer bytes.Buffer
	if err := t.Execute(&buffer, l.templateData); err != nil {
		return fmt.Errorf(`textfixtures: error on executing template in %s: %w`, f.fileName, err)

	f.content = buffer.Bytes()
	return nil