最近在做项目的时候,需要比对两个数据库的表结构差异。由于表的数量比较多,人工比对需要耗费大量时间,且不可复用,于是想到用 Python 写一个脚本来满足这个需求,下次有相同需求的时候只需修改 SQL 文件名即可。
compare_diff.py:
import re
import json
# Holder for one CREATE TABLE statement
class TableStmt:
    """Pairs a table name with the text of its CREATE TABLE statement."""

    # Name of the table the statement creates
    table_name = ""
    # Full text of the CREATE TABLE statement
    create_stmt = ""
# Parsed table
class Table(object):
    """A parsed table: its name, its fields and its indexes.

    The lists are created per instance in ``__init__``. Declaring them as
    class attributes (``fields = []``) would make every ``Table`` share the
    same two list objects, so appending to one table would silently change
    all the others.
    """

    def __init__(self):
        # Table name
        self.table_name = ""
        # Field objects belonging to this table
        self.fields = []
        # Index objects belonging to this table
        self.indexes = []
# Parsed column
class Field:
    """One column of a table: its name and its type."""

    # Column name
    field_name = ""
    # Column type
    field_type = ""
# Parsed index
class Index:
    """One index of a table: its name, its kind and its column list."""

    # Index name
    name = ""
    # Index kind
    type = ""
    # Column list of the index, stored as a single string
    columns = ""
# Custom JSON serializer hook; optional, only used when printing parsed objects
def obj_2_dict(obj):
    """Turn a Field or an Index into a plain dict.

    Intended to be passed as ``default=`` to ``json.dumps``.

    Raises:
        TypeError: if ``obj`` is neither a Field nor an Index.
    """
    if isinstance(obj, Field):
        return dict(field_name=obj.field_name, field_type=obj.field_type)
    if isinstance(obj, Index):
        return dict(name=obj.name, type=obj.type, columns=obj.columns)
    # Anything else is unsupported
    raise TypeError(f"Type {type(obj)} is not serializable")
# Matches one complete CREATE TABLE statement: from "CREATE TABLE `name`"
# through the closing ")" followed by ENGINE and the table options, up to the
# terminating ";". The table name is captured in the group "table_name".
create_table_pattern = re.compile(
    r"CREATE TABLE `(?P<table_name>\w+)`.*?\)\s*ENGINE[A-Za-z0-9=_ ''\n\r\u4e00-\u9fa5]+;",
    flags=re.IGNORECASE | re.DOTALL,
)
# Matches a column definition at the start of a line: a back-quoted column
# name (group "field") followed by its base type with an optional
# "(n)" / "(n,m)" suffix (group "type"). Everything after the type is ignored.
table_pattern = re.compile(
    r"^\s*`(?P<field>\w+)`\s+(?P<type>[a-zA-Z]+(?:\(\d+(?:,\d+)?\))?)",
    flags=re.MULTILINE,
)
# Matches index definitions. The four alternatives keep their original order,
# so the capture groups are laid out as:
#   1-2  plain KEY      (name, columns); the look-behind skips KEY right after a back-quote
#   3    PRIMARY KEY    (columns)
#   4-5  UNIQUE KEY     (name, columns)
#   6-7  FULLTEXT KEY   (name, columns)
index_pattern = re.compile(
    "|".join((
        r'(?<!`)KEY\s+`?(\w+)`?\s*\(([^)]+)\)',
        r'PRIMARY\s+KEY\s*\(([^)]+)\)',
        r'UNIQUE\s+KEY\s+`?(\w+)`?\s*\(([^)]+)\)',
        r'FULLTEXT\s+KEY\s+`?(\w+)`?\s*\(([^)]+)\)',
    )),
    flags=re.IGNORECASE,
)
# Collect every table name together with its CREATE TABLE statement
def extract_create_table_statements(sql_script):
    """Return one TableStmt per CREATE TABLE statement found in ``sql_script``.

    Statements are located with ``create_table_pattern`` and returned in the
    order they appear in the script.
    """
    statements = []
    for found in create_table_pattern.finditer(sql_script):
        entry = TableStmt()
        # Table names are normalized to lower case
        entry.table_name = found['table_name'].lower()
        # Group 0 is the whole matched statement
        entry.create_stmt = found[0].strip()
        statements.append(entry)
    return statements
# Extract index definitions
def _normalize_index_columns(columns):
    """Lower-case a raw column list and join its entries with ", ".

    Whitespace around each comma-separated entry is dropped, so that
    "`a`,`b`" and "`a`, `b`" produce the same string and are not reported
    as different column lists merely because of formatting.
    """
    return ", ".join(part.strip() for part in columns.lower().split(","))


def extract_indexes(sql):
    """Return the Index objects described by ``sql``.

    Every match of ``index_pattern`` yields one Index whose ``type`` is one
    of 'index', 'primary key', 'unique index' or 'fulltext index'. Names and
    column lists are lower-cased; a primary key always gets the name
    'primary'. Column lists are whitespace-normalized (see
    ``_normalize_index_columns``).
    """
    indexes = []
    # findall yields 7-tuples: (name, cols) for a plain KEY, (cols) for the
    # PRIMARY KEY, then (name, cols) for UNIQUE KEY and for FULLTEXT KEY.
    for match in index_pattern.findall(sql):
        index = Index()
        if match[0]:  # plain index
            index.type = 'index'
            index.name = match[0].lower()
            index.columns = _normalize_index_columns(match[1])
        elif match[2]:  # primary key
            index.type = 'primary key'
            index.name = 'primary'
            index.columns = _normalize_index_columns(match[2])
        elif match[3]:  # unique index
            index.type = 'unique index'
            index.name = match[3].lower()
            index.columns = _normalize_index_columns(match[4])
        elif match[5]:  # full-text index
            index.type = 'fulltext index'
            index.name = match[5].lower()
            index.columns = _normalize_index_columns(match[6])
        indexes.append(index)
    return indexes
# Extract column definitions
def extract_fields(sql):
    """Return one Field per column definition matched by ``table_pattern``.

    Both the column name and the column type are lower-cased.
    """
    columns = []
    for found in table_pattern.finditer(sql):
        column = Field()
        column.field_name = found['field'].lower()
        column.field_type = found['type'].lower()
        columns.append(column)
    return columns
# Build the parsed table for one CREATE TABLE statement
def extract_table_info(tableStmt: TableStmt):
    """Parse ``tableStmt`` into a Table holding its name, fields and indexes."""
    parsed = Table()
    parsed.table_name = tableStmt.table_name.lower()
    ddl = tableStmt.create_stmt
    # Columns first, then indexes, both read from the same statement text
    parsed.fields = extract_fields(ddl)
    parsed.indexes = extract_indexes(ddl)
    return parsed
# Parse every table contained in a SQL script
def get_all_tables(sql_script):
    """Return a dict mapping table name to its parsed Table.

    If the same table name occurs more than once, the later statement wins.
    """
    parsed_tables = map(extract_table_info, extract_create_table_statements(sql_script))
    return {table.table_name: table for table in parsed_tables}
# Compare the columns of two tables
def compare_fields(source: Table, target: Table):
    """Compare the fields of ``source`` and ``target`` by field name.

    Returns:
        A 3-tuple of lists:
        names of fields only in ``source``,
        descriptions of fields whose types differ,
        names of fields only in ``target``.
    """
    source_by_name = {col.field_name: col for col in source.fields}
    target_by_name = {col.field_name: col for col in target.fields}

    only_in_source = []
    type_mismatches = []
    for col in source.fields:
        name = col.field_name
        if name in target_by_name:
            other = target_by_name[name]
            if col.field_type != other.field_type:
                type_mismatches.append(
                    f"field={name}, source type: {col.field_type}, target type: {other.field_type}")
        else:
            # Present in source, absent from target
            only_in_source.append(name)

    # Types need no second check here: fields present on both sides were
    # already compared in the loop above.
    only_in_target = [
        col.field_name for col in target.fields if col.field_name not in source_by_name
    ]
    return only_in_source, type_mismatches, only_in_target
# Compare the indexes of two tables
def compare_indexes(source: Table, target: Table):
    """Compare the indexes of ``source`` and ``target`` by index name.

    Returns:
        A 4-tuple of lists, in this order:
        names of indexes only in ``source``,
        descriptions of indexes whose columns differ,
        descriptions of indexes whose types differ,
        names of indexes only in ``target``.
    """
    source_by_name = {idx.name: idx for idx in source.indexes}
    target_by_name = {idx.name: idx for idx in target.indexes}

    only_in_source = []
    column_mismatches = []
    type_mismatches = []
    for idx in source.indexes:
        if idx.name not in target_by_name:
            # Present in source, absent from target
            only_in_source.append(idx.name)
        else:
            other = target_by_name[idx.name]
            if idx.type != other.type:
                # Same name, different type; columns are not compared in this case
                type_mismatches.append(
                    f"name={idx.name}, source type: {idx.type}, target type: {other.type}")
            elif idx.columns != other.columns:
                # Same name and type, different columns
                column_mismatches.append(
                    f"name={idx.name}, source columns={idx.columns}, target columns={other.columns}")

    # Present in target, absent from source
    only_in_target = [idx.name for idx in target.indexes if idx.name not in source_by_name]
    return only_in_source, column_mismatches, type_mismatches, only_in_target
# Print one comparison result; an empty result means "no difference" and prints nothing
def print_diff(desc, compare_result):
    """Print ``desc`` followed by ``compare_result`` unless the result is empty."""
    if len(compare_result) == 0:
        return
    print(f"{desc} {compare_result}")
# Compare every table found in the two scripts
def _is_table_selected(table_name, white_list, black_list):
    """Return True when ``table_name`` passes the whitelist and blacklist filters."""
    # A non-empty whitelist restricts the comparison to its members
    if white_list and table_name not in white_list:
        return False
    # Blacklisted tables are never compared
    return table_name not in black_list


def compare_table(source_sql_script, target_sql_script):
    """Compare all tables of two SQL scripts and print the differences.

    For every selected table present in both scripts a header line is
    printed, followed by the field and index differences. Afterwards the
    tables present in only one of the scripts are listed.

    Table selection uses the module-level ``white_list_tables`` and
    ``black_list_tables``. Entries are matched case-insensitively, because
    parsed table names are always lower-cased; a case-sensitive match would
    silently ignore an entry such as "TABLE1".
    """
    source_table_map = get_all_tables(source_sql_script)
    target_table_map = get_all_tables(target_sql_script)
    white_list = {name.lower() for name in white_list_tables}
    black_list = {name.lower() for name in black_list_tables}

    source_table_not_in_target = []
    for key, source_table in source_table_map.items():
        if not _is_table_selected(key, white_list, black_list):
            continue
        if key not in target_table_map:
            # Table exists in source only
            source_table_not_in_target.append(key)
            continue
        target_table = target_table_map[key]
        # Compare fields
        (source_fields_not_in_target, fields_type_not_match,
         target_fields_not_in_source) = compare_fields(source_table, target_table)
        # Compare indexes
        (source_indexes_not_in_target, index_column_not_match,
         index_type_not_match, target_indexes_not_in_source) = compare_indexes(source_table, target_table)
        print(f"====== table = {key} ======")
        print_diff("source field not in target, fields:", source_fields_not_in_target)
        print_diff("target field not in source, fields:", target_fields_not_in_source)
        print_diff("field type not match:", fields_type_not_match)
        print_diff("source index not in target, indexes:", source_indexes_not_in_target)
        print_diff("target index not in source, indexes:", target_indexes_not_in_source)
        print_diff("index type not match:", index_type_not_match)
        print_diff("index column not match:", index_column_not_match)
        print("")

    # Tables that exist in target only
    target_table_not_in_source = [
        key for key in target_table_map
        if _is_table_selected(key, white_list, black_list) and key not in source_table_map
    ]
    print_diff("source table not in target, table list:", source_table_not_in_target)
    print_diff("target table not in source, table list:", target_table_not_in_source)
# Read a SQL file
def sql_read(file_name):
    """Return the whole content of ``file_name``, decoded as UTF-8."""
    with open(file_name, mode="r", encoding='utf-8') as handle:
        content = handle.read()
    return content
def print_all_tables(file_name="sql1.sql"):
    """Print every table parsed from a SQL file, with its fields and indexes.

    Args:
        file_name: path of the SQL script to parse. Defaults to "sql1.sql",
            the file that used to be hard-coded here, so existing calls
            without arguments behave as before.
    """
    for table_name, table in get_all_tables(sql_read(file_name)).items():
        print(table_name)
        # Field and Index objects are serialized through obj_2_dict
        print(json.dumps(table.fields, default=obj_2_dict, ensure_ascii=False, indent=4))
        print(json.dumps(table.indexes, default=obj_2_dict, ensure_ascii=False, indent=4))
        print("")
# print_all_tables()
# Whitelist / blacklist settings, for comparing only a subset of all tables
# Whitelist: when non-empty, only the tables listed here are compared
white_list_tables = []
# Blacklist: when non-empty, the tables listed here are not compared
black_list_tables = []
if __name__ == "__main__":
    # Note: MySQL is case-insensitive by default. If the database is configured
    # to be case-sensitive this script has to be adapted, because all table,
    # field and index names are lower-cased before they are compared.
    compare_table(sql_read("sql1.sql"), sql_read("sql2.sql"))
运行效果如下:
====== table = table1 ======
source field not in target, fields: ['age', 'email']
target field not in source, fields: ['name']
field type not match: ['field=created_at, source type: date, target type: bigint(20)', 'field=updated_at, source type: timestamp, target type: date']
source index not in target, indexes: ['unique_name']
target index not in source, indexes: ['idx_country_env']
====== table = table2 ======
index type not match: ['name=fulltext_index, source type: fulltext index, target type: index']
index column not match: ['name=index, source columns=`age`, target columns=`description`']
====== table = table3 ======
index column not match: ['name=primary, source columns=`id`, `value`, target columns=`value`, `id`']
source table not in target, table list: ['activity_instance']
target table not in source, table list: ['table5']
结果说明:
- 按照 table 来打印 source table 和 target table 的字段和索引差异,此时 table 在两个 sql 脚本里都存在
- 最后打印只在其中一个 sql 脚本里存在的 table list
sql1.sql:
CREATE TABLE `table1` (
`id` INT(11) NOT NULL AUTO_INCREMENT,
`age` INT(11) DEFAULT NULL,
`email` varchar(32) DEFAULT NULL COMMENT '邮箱',
`created_at` date DEFAULT NULL,
`updated_at` TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`id`),
UNIQUE KEY `unique_name` (`name`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT ='测试表';
CREATE TABLE `table2` (
`id` INT(11) NOT NULL,
`description` TEXT NOT NULL,
`created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (`id`),
UNIQUE KEY `unique_name` (`name`),
KEY `index` (`age`),
FULLTEXT KEY `fulltext_index` (`name`, `age`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
CREATE TABLE `table3` (
`id` INT(11) NOT NULL AUTO_INCREMENT,
`value` DECIMAL(10,2) NOT NULL,
`updated_at` TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`id`, `value`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
/******************************************/
/* DatabaseName = database */
/* TableName = activity_instance */
/******************************************/
CREATE TABLE `activity_instance`
(
`id` bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT '主键',
`gmt_create` bigint(20) NOT NULL COMMENT '创建时间',
`gmt_modified` bigint(20) NOT NULL COMMENT '修改时间',
`activity_name` varchar(400) NOT NULL COMMENT '活动名称',
`benefit_type` varchar(16) DEFAULT NULL,
`benefit_id` varchar(32) DEFAULT NULL,
PRIMARY KEY (`id`),
KEY `idx_country_env` (`env`, `country_code`),
KEY `idx_benefit_type_id` (`benefit_type`, `benefit_id`)
) ENGINE = InnoDB
AUTO_INCREMENT = 139
DEFAULT CHARSET = utf8mb4 COMMENT ='活动时间模板表'
;
sql2.sql:
CREATE TABLE `TABLE1` (
`id` INT(11) NOT NULL AUTO_INCREMENT,
`name` VARCHAR(255) NOT NULL,
`created_at` bigint(20) DEFAULT NULL,
`updated_at` date ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`id`),
KEY `idx_country_env` (`env`, `country_code`),
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT ='测试表';
CREATE TABLE `table2` (
`id` INT(11) NOT NULL,
`description` TEXT NOT NULL,
`created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (`id`),
UNIQUE KEY `unique_name` (`name`),
KEY `index` (`description`),
KEY `fulltext_index` (`name`, `age`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
CREATE TABLE `table3` (
`id` INT(11) NOT NULL AUTO_INCREMENT,
`value` DECIMAL(10,2) NOT NULL,
`updated_at` TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`value`, `id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
CREATE TABLE `TABLE5` (
`id` INT(11) NOT NULL AUTO_INCREMENT,
`value` DECIMAL(10,2) NOT NULL,
`updated_at` TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
把 Python 脚本和两个 SQL 脚本拷贝下来,分别保存为同一个目录下的 3 个文件(compare_diff.py、sql1.sql、sql2.sql)即可,示例已在 Python 3.12 环境上成功运行。