You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							132 lines
						
					
					
						
							2.8 KiB
						
					
					
				
			
		
		
	
	
							132 lines
						
					
					
						
							2.8 KiB
						
					
					
				| package testfixtures
 | |
| 
 | |
| import (
 | |
| 	"database/sql"
 | |
| 	"fmt"
 | |
| 	"strings"
 | |
| )
 | |
| 
 | |
| // Oracle is the Oracle database helper for this package
 | |
| type Oracle struct {
 | |
| 	baseHelper
 | |
| 
 | |
| 	enabledConstraints []oracleConstraint
 | |
| 	sequences          []string
 | |
| }
 | |
| 
 | |
// oracleConstraint identifies a single constraint by the table it belongs to
// and by the constraint's own name.
type oracleConstraint struct {
	tableName, constraintName string
}
 | |
| 
 | |
| func (h *Oracle) init(db *sql.DB) error {
 | |
| 	var err error
 | |
| 
 | |
| 	h.enabledConstraints, err = h.getEnabledConstraints(db)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	h.sequences, err = h.getSequences(db)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (*Oracle) paramType() int {
 | |
| 	return paramTypeColon
 | |
| }
 | |
| 
 | |
| func (*Oracle) quoteKeyword(str string) string {
 | |
| 	return fmt.Sprintf("\"%s\"", strings.ToUpper(str))
 | |
| }
 | |
| 
 | |
| func (*Oracle) databaseName(db *sql.DB) (dbName string) {
 | |
| 	db.QueryRow("SELECT user FROM DUAL").Scan(&dbName)
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (*Oracle) getEnabledConstraints(db *sql.DB) ([]oracleConstraint, error) {
 | |
| 	constraints := make([]oracleConstraint, 0)
 | |
| 	rows, err := db.Query(`
 | |
|         SELECT table_name, constraint_name
 | |
|         FROM user_constraints
 | |
|         WHERE constraint_type = 'R'
 | |
|           AND status = 'ENABLED'
 | |
|     `)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	defer rows.Close()
 | |
| 	for rows.Next() {
 | |
| 		var constraint oracleConstraint
 | |
| 		rows.Scan(&constraint.tableName, &constraint.constraintName)
 | |
| 		constraints = append(constraints, constraint)
 | |
| 	}
 | |
| 	return constraints, nil
 | |
| }
 | |
| 
 | |
| func (*Oracle) getSequences(db *sql.DB) ([]string, error) {
 | |
| 	sequences := make([]string, 0)
 | |
| 	rows, err := db.Query("SELECT sequence_name FROM user_sequences")
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	defer rows.Close()
 | |
| 	for rows.Next() {
 | |
| 		var sequence string
 | |
| 		rows.Scan(&sequence)
 | |
| 		sequences = append(sequences, sequence)
 | |
| 	}
 | |
| 	return sequences, nil
 | |
| }
 | |
| 
 | |
| func (h *Oracle) resetSequences(db *sql.DB) error {
 | |
| 	for _, sequence := range h.sequences {
 | |
| 		_, err := db.Exec(fmt.Sprintf("DROP SEQUENCE %s", h.quoteKeyword(sequence)))
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		_, err = db.Exec(fmt.Sprintf("CREATE SEQUENCE %s START WITH %d", h.quoteKeyword(sequence), resetSequencesTo))
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (h *Oracle) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error {
 | |
| 	// re-enable after load
 | |
| 	defer func() {
 | |
| 		for _, c := range h.enabledConstraints {
 | |
| 			db.Exec(fmt.Sprintf("ALTER TABLE %s ENABLE CONSTRAINT %s", h.quoteKeyword(c.tableName), h.quoteKeyword(c.constraintName)))
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	// disable foreign keys
 | |
| 	for _, c := range h.enabledConstraints {
 | |
| 		_, err := db.Exec(fmt.Sprintf("ALTER TABLE %s DISABLE CONSTRAINT %s", h.quoteKeyword(c.tableName), h.quoteKeyword(c.constraintName)))
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	tx, err := db.Begin()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if err = loadFn(tx); err != nil {
 | |
| 		tx.Rollback()
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if err = tx.Commit(); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	return h.resetSequences(db)
 | |
| }
 | |
| 
 |