From 7f55c264092ef83f7ca54f7921affc0aca441866 Mon Sep 17 00:00:00 2001 From: renliangyu857 <2918490262@qq.com> Date: Tue, 18 Oct 2022 15:01:38 +0800 Subject: [PATCH] feature: support mysql update join sql (#4914) --- .../registry/MultiRegistryFactory.java | 17 +- .../exec/BaseTransactionalExecutor.java | 72 +++++- .../rm/datasource/exec/ExecuteTemplate.java | 11 + .../rm/datasource/exec/UpdateExecutor.java | 62 +++-- .../exec/mysql/MySQLUpdateJoinExecutor.java | 217 ++++++++++++++++++ .../datasource/undo/UndoExecutorFactory.java | 1 + .../exec/UpdateJoinExecutorTest.java | 96 ++++++++ .../datasource/mock/MockDatabaseMetaData.java | 14 +- .../mock/MockExecuteHandlerImpl.java | 39 +++- .../main/java/io/seata/sqlparser/SQLType.java | 6 +- .../seata/sqlparser/SQLUpdateRecognizer.java | 11 + .../seata/sqlparser/druid/BaseRecognizer.java | 8 - .../druid/mysql/MySQLUpdateRecognizer.java | 100 ++++++-- .../druid/oracle/BaseOracleRecognizer.java | 23 +- .../postgresql/BasePostgresqlRecognizer.java | 65 ++++++ .../druid/DruidSQLRecognizerFactoryTest.java | 68 +++--- .../druid/MySQLUpdateRecognizerTest.java | 9 + test/pom.xml | 6 + .../seata/at/mysql/MysqlUpdateJoinTest.java | 166 ++++++++++++++ test/src/test/resources/README.md | 21 ++ 20 files changed, 895 insertions(+), 117 deletions(-) create mode 100644 rm-datasource/src/main/java/io/seata/rm/datasource/exec/mysql/MySQLUpdateJoinExecutor.java create mode 100644 rm-datasource/src/test/java/io/seata/rm/datasource/exec/UpdateJoinExecutorTest.java create mode 100644 test/src/test/java/io/seata/at/mysql/MysqlUpdateJoinTest.java create mode 100644 test/src/test/resources/README.md diff --git a/discovery/seata-discovery-core/src/main/java/io/seata/discovery/registry/MultiRegistryFactory.java b/discovery/seata-discovery-core/src/main/java/io/seata/discovery/registry/MultiRegistryFactory.java index 25424847530..2ded7a176ad 100644 --- a/discovery/seata-discovery-core/src/main/java/io/seata/discovery/registry/MultiRegistryFactory.java +++ b/discovery/seata-discovery-core/src/main/java/io/seata/discovery/registry/MultiRegistryFactory.java @@ -1,14 +1,17 @@ /* - * Copyright 1999-2019 Seata.io Group. + * Copyright 1999-2019 Seata.io Group. * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on - * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package io.seata.discovery.registry; diff --git a/rm-datasource/src/main/java/io/seata/rm/datasource/exec/BaseTransactionalExecutor.java b/rm-datasource/src/main/java/io/seata/rm/datasource/exec/BaseTransactionalExecutor.java index 5c82d14dd9e..45a3c9a5a4a 100644 --- a/rm-datasource/src/main/java/io/seata/rm/datasource/exec/BaseTransactionalExecutor.java +++ b/rm-datasource/src/main/java/io/seata/rm/datasource/exec/BaseTransactionalExecutor.java @@ -22,7 +22,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Set; import java.util.StringJoiner; import java.util.TreeSet; @@ -205,6 +204,35 @@ protected String getColumnNameInSQL(String columnName) { return tableAlias == null ? columnName : tableAlias + "." + columnName; } + + /** + * Gets column name with table prefix + * + * @param table the table name + * @param tableAlias the tableAlias + * @param columnName the column name + * @return + */ + protected String getColumnNameWithTablePrefix(String table, String tableAlias, String columnName) { + return tableAlias == null ? (table == null ? columnName : table + "." + columnName) : (tableAlias + "." + columnName); + } + + /** + * Gets column name with table prefix + * + * @param table the table name + * @param tableAlias the tableAlias + * @param columnNames the column names + * @return + */ + protected List getColumnNamesWithTablePrefixList(String table,String tableAlias,List columnNames) { + List columnNameWithTablePrefix = new ArrayList<>(); + for (String columnName : columnNames) { + columnNameWithTablePrefix.add(this.getColumnNameWithTablePrefix(table,tableAlias,columnName)); + } + return columnNameWithTablePrefix; + } + /** * Gets several column name in sql. * @@ -212,7 +240,7 @@ protected String getColumnNameInSQL(String columnName) { * @return the column name in sql */ protected String getColumnNamesInSQL(List columnNameList) { - if (Objects.isNull(columnNameList) || columnNameList.isEmpty()) { + if (CollectionUtils.isEmpty(columnNameList)) { return null; } StringBuilder columnNamesStr = new StringBuilder(); @@ -225,6 +253,28 @@ protected String getColumnNamesInSQL(List columnNameList) { return columnNamesStr.toString(); } + /** + * Gets several column name in sql. + * + * @param table the table + * @param tableAlias the table alias + * @param columnNameList the column name + * @return the column name in sql + */ + protected String getColumnNamesWithTablePrefix(String table,String tableAlias, List columnNameList) { + if (CollectionUtils.isEmpty(columnNameList)) { + return null; + } + StringBuilder columnNamesStr = new StringBuilder(); + for (int i = 0; i < columnNameList.size(); i++) { + if (i > 0) { + columnNamesStr.append(" , "); + } + columnNamesStr.append(getColumnNameWithTablePrefix(table,tableAlias, columnNameList.get(i))); + } + return columnNamesStr.toString(); + } + /** * Gets from table in sql. * @@ -268,13 +318,28 @@ protected TableMeta getTableMeta(String tableName) { * @return true: contains pk false: not contains pk */ protected boolean containsPK(List columns) { - if (columns == null || columns.isEmpty()) { + if (CollectionUtils.isEmpty(columns)) { return false; } List newColumns = ColumnUtils.delEscape(columns, getDbType()); return getTableMeta().containsPK(newColumns); } + /** + * the columns contains table meta pk + * + * @param tableName the tableName + * @param columns the column name list + * @return true: contains pk false: not contains pk + */ + protected boolean containsPK(String tableName,List columns) { + if (CollectionUtils.isEmpty(columns)) { + return false; + } + List newColumns = ColumnUtils.delEscape(columns, getDbType()); + return getTableMeta(tableName).containsPK(newColumns); + } + /** * compare column name and primary key name @@ -384,7 +449,6 @@ protected SQLUndoLog buildUndoItem(TableRecords beforeImage, TableRecords afterI return sqlUndoLog; } - /** * build a BeforeImage * diff --git a/rm-datasource/src/main/java/io/seata/rm/datasource/exec/ExecuteTemplate.java b/rm-datasource/src/main/java/io/seata/rm/datasource/exec/ExecuteTemplate.java index 7006d1bb3c5..d317cd23d19 100644 --- a/rm-datasource/src/main/java/io/seata/rm/datasource/exec/ExecuteTemplate.java +++ b/rm-datasource/src/main/java/io/seata/rm/datasource/exec/ExecuteTemplate.java @@ -26,8 +26,10 @@ import io.seata.core.model.BranchType; import io.seata.rm.datasource.StatementProxy; import io.seata.rm.datasource.exec.mysql.MySQLInsertOnDuplicateUpdateExecutor; +import io.seata.rm.datasource.exec.mysql.MySQLUpdateJoinExecutor; import io.seata.rm.datasource.sql.SQLVisitorFactory; import io.seata.sqlparser.SQLRecognizer; +import io.seata.sqlparser.SQLType; import io.seata.sqlparser.util.JdbcConstants; /** @@ -113,6 +115,15 @@ public static T execute(List sqlRecogniz throw new NotSupportYetException(dbType + " not support to INSERT_ON_DUPLICATE_UPDATE"); } break; + case UPDATE_JOIN: + switch (dbType) { + case JdbcConstants.MYSQL: + executor = new MySQLUpdateJoinExecutor<>(statementProxy,statementCallback,sqlRecognizer); + break; + default: + throw new NotSupportYetException(dbType + " not support to " + SQLType.UPDATE_JOIN.name()); + } + break; default: executor = new PlainExecutor<>(statementProxy, statementCallback); break; diff --git a/rm-datasource/src/main/java/io/seata/rm/datasource/exec/UpdateExecutor.java b/rm-datasource/src/main/java/io/seata/rm/datasource/exec/UpdateExecutor.java index d66824cf4e5..491ea23017b 100644 --- a/rm-datasource/src/main/java/io/seata/rm/datasource/exec/UpdateExecutor.java +++ b/rm-datasource/src/main/java/io/seata/rm/datasource/exec/UpdateExecutor.java @@ -36,6 +36,7 @@ import io.seata.rm.datasource.sql.struct.TableRecords; import io.seata.sqlparser.SQLRecognizer; import io.seata.sqlparser.SQLUpdateRecognizer; +import io.seata.common.util.CollectionUtils; /** * The type Update executor. @@ -49,7 +50,7 @@ public class UpdateExecutor extends AbstractDMLBaseExecu private static final Configuration CONFIG = ConfigurationFactory.getInstance(); private static final boolean ONLY_CARE_UPDATE_COLUMNS = CONFIG.getBoolean( - ConfigurationKeys.TRANSACTION_UNDO_ONLY_CARE_UPDATE_COLUMNS, DefaultValues.DEFAULT_ONLY_CARE_UPDATE_COLUMNS); + ConfigurationKeys.TRANSACTION_UNDO_ONLY_CARE_UPDATE_COLUMNS, DefaultValues.DEFAULT_ONLY_CARE_UPDATE_COLUMNS); /** * Instantiates a new Update executor. @@ -59,7 +60,7 @@ public class UpdateExecutor extends AbstractDMLBaseExecu * @param sqlRecognizer the sql recognizer */ public UpdateExecutor(StatementProxy statementProxy, StatementCallback statementCallback, - SQLRecognizer sqlRecognizer) { + SQLRecognizer sqlRecognizer) { super(statementProxy, statementCallback, sqlRecognizer); } @@ -73,7 +74,6 @@ protected TableRecords beforeImage() throws SQLException { private String buildBeforeImageSQL(TableMeta tableMeta, ArrayList> paramAppenderList) { SQLUpdateRecognizer recognizer = (SQLUpdateRecognizer) sqlRecognizer; - List updateColumns = recognizer.getUpdateColumnsIsSimplified(); StringBuilder prefix = new StringBuilder("SELECT "); StringBuilder suffix = new StringBuilder(" FROM ").append(getFromTableInSQL()); String whereCondition = buildWhereCondition(recognizer, paramAppenderList); @@ -90,24 +90,9 @@ private String buildBeforeImageSQL(TableMeta tableMeta, ArrayList> } suffix.append(" FOR UPDATE"); StringJoiner selectSQLJoin = new StringJoiner(", ", prefix.toString(), suffix.toString()); - if (ONLY_CARE_UPDATE_COLUMNS) { - if (!containsPK(updateColumns)) { - selectSQLJoin.add(getColumnNamesInSQL(tableMeta.getEscapePkNameList(getDbType()))); - } - for (String columnName : updateColumns) { - selectSQLJoin.add(columnName); - } - - // The on update xxx columns will be auto update by db, so it's also the actually updated columns - List onUpdateColumns = tableMeta.getOnUpdateColumnsOnlyName(); - onUpdateColumns.removeAll(updateColumns); - for (String onUpdateColumn : onUpdateColumns) { - selectSQLJoin.add(ColumnUtils.addEscape(onUpdateColumn, getDbType())); - } - } else { - for (String columnName : tableMeta.getAllColumns().keySet()) { - selectSQLJoin.add(ColumnUtils.addEscape(columnName, getDbType())); - } + List needUpdateColumns = getNeedUpdateColumns(tableMeta.getTableName(), sqlRecognizer.getTableAlias(), recognizer.getUpdateColumnsIsSimplified()); + for (String needUpdateColumn : needUpdateColumns) { + selectSQLJoin.add(needUpdateColumn); } return selectSQLJoin.toString(); } @@ -134,28 +119,37 @@ private String buildAfterImageSQL(TableMeta tableMeta, TableRecords beforeImage) String whereSql = SqlGenerateUtils.buildWhereConditionByPKs(tableMeta.getPrimaryKeyOnlyName(), beforeImage.pkRows().size(), getDbType()); String suffix = " FROM " + getFromTableInSQL() + " WHERE " + whereSql; StringJoiner selectSQLJoiner = new StringJoiner(", ", prefix.toString(), suffix); + SQLUpdateRecognizer recognizer = (SQLUpdateRecognizer) sqlRecognizer; + List needUpdateColumns = getNeedUpdateColumns(tableMeta.getTableName(), sqlRecognizer.getTableAlias(), recognizer.getUpdateColumnsIsSimplified()); + for (String needUpdateColumn : needUpdateColumns) { + selectSQLJoiner.add(needUpdateColumn); + } + return selectSQLJoiner.toString(); + } + + protected List getNeedUpdateColumns(String table, String tableAlias, List originUpdateColumns) { + List needUpdateColumns = new ArrayList<>(); + TableMeta tableMeta = getTableMeta(table); if (ONLY_CARE_UPDATE_COLUMNS) { - SQLUpdateRecognizer recognizer = (SQLUpdateRecognizer) sqlRecognizer; - List updateColumns = recognizer.getUpdateColumnsIsSimplified(); - if (!containsPK(updateColumns)) { - selectSQLJoiner.add(getColumnNamesInSQL(tableMeta.getEscapePkNameList(getDbType()))); - } - for (String columnName : updateColumns) { - selectSQLJoiner.add(columnName); + if (!containsPK(table, originUpdateColumns)) { + List pkNameList = tableMeta.getEscapePkNameList(getDbType()); + if (CollectionUtils.isNotEmpty(pkNameList)) { + needUpdateColumns.add(getColumnNamesWithTablePrefix(table,tableAlias,pkNameList)); + } } + needUpdateColumns.addAll(originUpdateColumns); // The on update xxx columns will be auto update by db, so it's also the actually updated columns List onUpdateColumns = tableMeta.getOnUpdateColumnsOnlyName(); - onUpdateColumns.removeAll(updateColumns); + onUpdateColumns.removeAll(originUpdateColumns); for (String onUpdateColumn : onUpdateColumns) { - selectSQLJoiner.add(ColumnUtils.addEscape(onUpdateColumn, getDbType())); + needUpdateColumns.add(ColumnUtils.addEscape(onUpdateColumn, getDbType())); } } else { for (String columnName : tableMeta.getAllColumns().keySet()) { - selectSQLJoiner.add(ColumnUtils.addEscape(columnName, getDbType())); + needUpdateColumns.add(ColumnUtils.addEscape(columnName, getDbType())); } } - return selectSQLJoiner.toString(); + return needUpdateColumns; } - -} +} \ No newline at end of file diff --git a/rm-datasource/src/main/java/io/seata/rm/datasource/exec/mysql/MySQLUpdateJoinExecutor.java b/rm-datasource/src/main/java/io/seata/rm/datasource/exec/mysql/MySQLUpdateJoinExecutor.java new file mode 100644 index 00000000000..1895968174c --- /dev/null +++ b/rm-datasource/src/main/java/io/seata/rm/datasource/exec/mysql/MySQLUpdateJoinExecutor.java @@ -0,0 +1,217 @@ +/* + * Copyright 1999-2019 Seata.io Group. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.seata.rm.datasource.exec.mysql; + +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.StringJoiner; + +import io.seata.common.util.CollectionUtils; +import io.seata.rm.datasource.ConnectionProxy; +import io.seata.rm.datasource.sql.struct.TableMetaCacheFactory; +import io.seata.rm.datasource.undo.SQLUndoLog; +import io.seata.sqlparser.SQLType; +import io.seata.common.exception.ShouldNeverHappenException; +import io.seata.common.util.IOUtil; +import io.seata.common.util.StringUtils; +import io.seata.rm.datasource.SqlGenerateUtils; +import io.seata.rm.datasource.StatementProxy; +import io.seata.rm.datasource.exec.StatementCallback; +import io.seata.rm.datasource.exec.UpdateExecutor; +import io.seata.rm.datasource.sql.struct.TableMeta; +import io.seata.rm.datasource.sql.struct.TableRecords; +import io.seata.sqlparser.SQLRecognizer; +import io.seata.sqlparser.SQLUpdateRecognizer; + + + +/** + * @author renliangyu857 + */ +public class MySQLUpdateJoinExecutor extends UpdateExecutor { + private static final String DOT = "."; + private final Map beforeImagesMap = new LinkedHashMap<>(4); + private final Map afterImagesMap = new LinkedHashMap<>(4); + + /** + * Instantiates a new Update executor. + * + * @param statementProxy the statement proxy + * @param statementCallback the statement callback + * @param sqlRecognizer the sql recognizer + */ + public MySQLUpdateJoinExecutor(StatementProxy statementProxy, StatementCallback statementCallback, + SQLRecognizer sqlRecognizer) { + super(statementProxy, statementCallback, sqlRecognizer); + } + + @Override + protected TableRecords beforeImage() throws SQLException { + ArrayList> paramAppenderList = new ArrayList<>(); + SQLUpdateRecognizer recognizer = (SQLUpdateRecognizer) sqlRecognizer; + String tableNames = recognizer.getTableName(); + // update join sql,like update t1 inner join t2 on t1.id = t2.id set t1.name = ?; tableItems = {"update t1 inner join t2","t1","t2"} + String[] tableItems = tableNames.split(recognizer.MULTI_TABLE_NAME_SEPERATOR); + String joinTable = tableItems[0]; + int itemTableIndex = 1; + for (int i = itemTableIndex; i < tableItems.length; i++) { + List itemTableUpdateColumns = getItemUpdateColumns(this.getTableMeta(tableItems[i]), recognizer.getUpdateColumns()); + if (CollectionUtils.isEmpty(itemTableUpdateColumns)) { + continue; + } + String selectSQL = buildBeforeImageSQL(joinTable, tableItems[i], itemTableUpdateColumns, paramAppenderList); + TableRecords tableRecords = buildTableRecords(getTableMeta(tableItems[i]), selectSQL, paramAppenderList); + beforeImagesMap.put(tableItems[i], tableRecords); + } + return null; + } + + private String buildBeforeImageSQL(String joinTable, String itemTable, List itemTableUpdateColumns, + ArrayList> paramAppenderList) { + SQLUpdateRecognizer recognizer = (SQLUpdateRecognizer) sqlRecognizer; + StringBuilder prefix = new StringBuilder("SELECT "); + StringBuilder suffix = new StringBuilder(" FROM ").append(joinTable); + String whereCondition = buildWhereCondition(recognizer, paramAppenderList); + String orderByCondition = buildOrderCondition(recognizer, paramAppenderList); + String limitCondition = buildLimitCondition(recognizer, paramAppenderList); + if (StringUtils.isNotBlank(whereCondition)) { + suffix.append(WHERE).append(whereCondition); + } + if (StringUtils.isNotBlank(orderByCondition)) { + suffix.append(" ").append(orderByCondition); + } + if (StringUtils.isNotBlank(limitCondition)) { + suffix.append(" ").append(limitCondition); + } + suffix.append(" FOR UPDATE"); + StringJoiner selectSQLJoin = new StringJoiner(", ", prefix.toString(), suffix.toString()); + List needUpdateColumns = getNeedUpdateColumns(itemTable, recognizer.getTableAlias(itemTable), itemTableUpdateColumns); + for (String needUpdateColumn : needUpdateColumns) { + selectSQLJoin.add(needUpdateColumn); + } + return selectSQLJoin.toString(); + } + + @Override + protected TableRecords afterImage(TableRecords beforeImage) throws SQLException { + SQLUpdateRecognizer recognizer = (SQLUpdateRecognizer) sqlRecognizer; + String tableNames = recognizer.getTableName(); + String[] tableItems = tableNames.split(recognizer.MULTI_TABLE_NAME_SEPERATOR); + String joinTable = tableItems[0]; + int itemTableIndex = 1; + for (int i = itemTableIndex; i < tableItems.length; i++) { + TableRecords tableBeforeImage = beforeImagesMap.get(tableItems[i]); + if (tableBeforeImage == null) { + continue; + } + String selectSQL = buildAfterImageSQL(joinTable, tableItems[i], tableBeforeImage); + ResultSet rs = null; + try (PreparedStatement pst = statementProxy.getConnection().prepareStatement(selectSQL)) { + SqlGenerateUtils.setParamForPk(tableBeforeImage.pkRows(), getTableMeta(tableItems[i]).getPrimaryKeyOnlyName(), pst); + rs = pst.executeQuery(); + TableRecords afterImage = TableRecords.buildRecords(getTableMeta(tableItems[i]), rs); + afterImagesMap.put(tableItems[i], afterImage); + } finally { + IOUtil.close(rs); + } + } + return null; + } + + private String buildAfterImageSQL(String joinTable, String itemTable, + TableRecords beforeImage) throws SQLException { + SQLUpdateRecognizer recognizer = (SQLUpdateRecognizer) sqlRecognizer; + TableMeta itemTableMeta = getTableMeta(itemTable); + StringBuilder prefix = new StringBuilder("SELECT "); + String whereSql = SqlGenerateUtils.buildWhereConditionByPKs(getColumnNamesWithTablePrefixList(itemTable, recognizer.getTableAlias(itemTable), itemTableMeta.getPrimaryKeyOnlyName()), beforeImage.pkRows().size(), getDbType()); + String suffix = " FROM " + joinTable + " WHERE " + whereSql; + StringJoiner selectSQLJoiner = new StringJoiner(", ", prefix.toString(), suffix); + List itemTableUpdateColumns = getItemUpdateColumns(itemTableMeta, recognizer.getUpdateColumns()); + List needUpdateColumns = getNeedUpdateColumns(itemTable, recognizer.getTableAlias(itemTable), itemTableUpdateColumns); + for (String needUpdateColumn : needUpdateColumns) { + selectSQLJoiner.add(needUpdateColumn); + } + return selectSQLJoiner.toString(); + } + + private List getItemUpdateColumns(TableMeta itemTableMeta, List updateColumns) { + List itemUpdateColumns = new ArrayList<>(); + Set itemTableAllColumns = itemTableMeta.getAllColumns().keySet(); + String itemTableName = itemTableMeta.getTableName(); + String itemTableNameAlias = ((SQLUpdateRecognizer) sqlRecognizer).getTableAlias(itemTableName); + for (String updateColumn : updateColumns) { + if (updateColumn.contains(DOT)) { + String[] specificTableColumn = updateColumn.split("\\."); + String tableNamePrefix = specificTableColumn[0]; + String column = specificTableColumn[1]; + if ((tableNamePrefix.equals(itemTableName) || tableNamePrefix.equals(itemTableNameAlias)) && itemTableAllColumns.contains(column)) { + itemUpdateColumns.add(updateColumn); + } + } else if (itemTableAllColumns.contains(updateColumn)) { + itemUpdateColumns.add(updateColumn); + } + } + return itemUpdateColumns; + } + + @Override + protected void prepareUndoLog(TableRecords beforeImage, TableRecords afterImage) throws SQLException { + if (CollectionUtils.isEmpty(beforeImagesMap) || CollectionUtils.isEmpty(afterImagesMap)) { + throw new IllegalStateException("images can not be null"); + } + for (Map.Entry entry : beforeImagesMap.entrySet()) { + String tableName = entry.getKey(); + TableRecords tableBeforeImage = entry.getValue(); + TableRecords tableAfterImage = afterImagesMap.get(tableName); + if (tableBeforeImage.getRows().size() != tableAfterImage.getRows().size()) { + throw new ShouldNeverHappenException("Before image size is not equaled to after image size, probably because you updated the primary keys."); + } + super.prepareUndoLog(tableBeforeImage, tableAfterImage); + } + } + + @Override + protected TableMeta getTableMeta(String tableName) { + ConnectionProxy connectionProxy = statementProxy.getConnectionProxy(); + return TableMetaCacheFactory.getTableMetaCache(connectionProxy.getDbType()) + .getTableMeta(connectionProxy.getTargetConnection(), tableName, connectionProxy.getDataSourceProxy().getResourceId()); + } + + /** + * build a SQLUndoLog + * + * @param beforeImage the before image + * @param afterImage the after image + * @return sql undo log + */ + protected SQLUndoLog buildUndoItem(TableRecords beforeImage, TableRecords afterImage) { + SQLType sqlType = sqlRecognizer.getSQLType(); + String tableName = beforeImage.getTableName(); + SQLUndoLog sqlUndoLog = new SQLUndoLog(); + sqlUndoLog.setSqlType(sqlType); + sqlUndoLog.setTableName(tableName); + sqlUndoLog.setBeforeImage(beforeImage); + sqlUndoLog.setAfterImage(afterImage); + return sqlUndoLog; + } +} diff --git a/rm-datasource/src/main/java/io/seata/rm/datasource/undo/UndoExecutorFactory.java b/rm-datasource/src/main/java/io/seata/rm/datasource/undo/UndoExecutorFactory.java index 8ca0c326ec6..8f368d4c74a 100644 --- a/rm-datasource/src/main/java/io/seata/rm/datasource/undo/UndoExecutorFactory.java +++ b/rm-datasource/src/main/java/io/seata/rm/datasource/undo/UndoExecutorFactory.java @@ -39,6 +39,7 @@ public static AbstractUndoExecutor getUndoExecutor(String dbType, SQLUndoLog sql result = holder.getInsertExecutor(sqlUndoLog); break; case UPDATE: + case UPDATE_JOIN: result = holder.getUpdateExecutor(sqlUndoLog); break; case DELETE: diff --git a/rm-datasource/src/test/java/io/seata/rm/datasource/exec/UpdateJoinExecutorTest.java b/rm-datasource/src/test/java/io/seata/rm/datasource/exec/UpdateJoinExecutorTest.java new file mode 100644 index 00000000000..971dfd03373 --- /dev/null +++ b/rm-datasource/src/test/java/io/seata/rm/datasource/exec/UpdateJoinExecutorTest.java @@ -0,0 +1,96 @@ +/* + * Copyright 1999-2019 Seata.io Group. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.seata.rm.datasource.exec; + +import java.lang.reflect.Field; +import java.sql.SQLException; +import java.sql.Types; +import java.util.List; + +import com.alibaba.druid.mock.MockStatement; +import com.alibaba.druid.mock.MockStatementBase; +import com.alibaba.druid.pool.DruidDataSource; +import com.alibaba.druid.sql.SQLUtils; +import com.alibaba.druid.sql.ast.SQLStatement; +import com.alibaba.druid.util.JdbcConstants; +import com.google.common.collect.Lists; +import io.seata.rm.datasource.ConnectionProxy; +import io.seata.rm.datasource.DataSourceProxy; +import io.seata.rm.datasource.StatementProxy; +import io.seata.rm.datasource.exec.mysql.MySQLUpdateJoinExecutor; +import io.seata.rm.datasource.mock.MockDriver; +import io.seata.rm.datasource.sql.struct.TableRecords; +import io.seata.sqlparser.druid.mysql.MySQLUpdateRecognizer; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +/** + * @author renliangyu857 + */ +public class UpdateJoinExecutorTest { + @Test + public void testUpdateJoinUndoLog() throws SQLException { + List returnValueColumnLabels = Lists.newArrayList("id", "name"); + Object[][] columnMetas = new Object[][]{ + new Object[]{"", "", "t1", "id", Types.INTEGER, "INTEGER", 64, 0, 10, 1, "", "", 0, 0, 64, 1, "NO", "YES"}, + new Object[]{"", "", "t1", "name", Types.VARCHAR, "VARCHAR", 64, 0, 10, 0, "", "", 0, 0, 64, 2, "YES", "NO"}, + new Object[]{"", "", "t2", "id", Types.INTEGER, "INTEGER", 64, 0, 10, 1, "", "", 0, 0, 64, 1, "NO", "YES"}, + new Object[]{"", "", "t2", "name", Types.VARCHAR, "VARCHAR", 64, 0, 10, 0, "", "", 0, 0, 64, 2, "YES", "NO"}, + new Object[]{"", "", "t1 inner join t2 on t1.id = t2.id", "id", Types.VARCHAR, "VARCHAR", 64, 0, 10, 0, "", "", 0, 0, 64, 2, "YES", "NO"}, + new Object[]{"", "", "t1 inner join t2 on t1.id = t2.id", "name", Types.VARCHAR, "VARCHAR", 64, 0, 10, 0, "", "", 0, 0, 64, 2, "YES", "NO"}, + }; + Object[][] indexMetas = new Object[][]{ + new Object[]{"PRIMARY", "id", false, "", 3, 1, "A", 34}, + }; + Object[][] beforeReturnValue = new Object[][]{ + new Object[]{1, "Tom"}, + }; + StatementProxy beforeMockStatementProxy = mockStatementProxy(returnValueColumnLabels, beforeReturnValue, columnMetas, indexMetas); + String sql = "update t1 inner join t2 on t1.id = t2.id set t1.name = 'WILL',t2.name = 'WILL'"; + List asts = SQLUtils.parseStatements(sql, JdbcConstants.MYSQL); + MySQLUpdateRecognizer recognizer = new MySQLUpdateRecognizer(sql, asts.get(0)); + UpdateExecutor mySQLUpdateJoinExecutor = new MySQLUpdateJoinExecutor(beforeMockStatementProxy, (statement, args) -> { + return null; + }, recognizer); + TableRecords beforeImage = mySQLUpdateJoinExecutor.beforeImage(); + Object[][] afterReturnValue = new Object[][]{ + new Object[]{1, "WILL"}, + }; + StatementProxy afterMockStatementProxy = mockStatementProxy(returnValueColumnLabels, afterReturnValue, columnMetas, indexMetas); + mySQLUpdateJoinExecutor.statementProxy = afterMockStatementProxy; + TableRecords afterImage = mySQLUpdateJoinExecutor.afterImage(beforeImage); + Assertions.assertDoesNotThrow(()->mySQLUpdateJoinExecutor.prepareUndoLog(beforeImage, afterImage)); + } + + private StatementProxy mockStatementProxy(List returnValueColumnLabels, Object[][] returnValue, Object[][] columnMetas, Object[][] indexMetas) { + MockDriver mockDriver = new MockDriver(returnValueColumnLabels, returnValue, columnMetas, indexMetas); + DruidDataSource dataSource = new DruidDataSource(); + dataSource.setUrl("jdbc:mock:xxx"); + dataSource.setDriver(mockDriver); + + DataSourceProxy dataSourceProxy = new DataSourceProxy(dataSource); + try { + Field field = dataSourceProxy.getClass().getDeclaredField("dbType"); + field.setAccessible(true); + field.set(dataSourceProxy, "mysql"); + ConnectionProxy connectionProxy = new ConnectionProxy(dataSourceProxy, dataSource.getConnection().getConnection()); + MockStatementBase mockStatement = new MockStatement(dataSource.getConnection().getConnection()); + return new StatementProxy(connectionProxy, mockStatement); + } catch (Exception e) { + throw new RuntimeException("init failed"); + } + } +} diff --git a/rm-datasource/src/test/java/io/seata/rm/datasource/mock/MockDatabaseMetaData.java b/rm-datasource/src/test/java/io/seata/rm/datasource/mock/MockDatabaseMetaData.java index ed2bf06afb6..888b0d0b593 100644 --- a/rm-datasource/src/test/java/io/seata/rm/datasource/mock/MockDatabaseMetaData.java +++ b/rm-datasource/src/test/java/io/seata/rm/datasource/mock/MockDatabaseMetaData.java @@ -20,6 +20,7 @@ import java.sql.ResultSet; import java.sql.RowIdLifetime; import java.sql.SQLException; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -724,8 +725,17 @@ public ResultSet getTableTypes() throws SQLException { @Override public ResultSet getColumns(String catalog, String schemaPattern, String tableNamePattern, String columnNamePattern) throws SQLException { - return new MockResultSet((MockStatementBase)this.connection.createStatement()) - .mockResultSet(columnMetaColumnLabels, columnsMetasReturnValue); + List metas = new ArrayList<>(); + for (Object[] meta : columnsMetasReturnValue) { + if (tableNamePattern.equals(meta[2].toString())) { + metas.add(meta); + } + } + if(metas.isEmpty()){ + metas = Arrays.asList(columnsMetasReturnValue); + } + return new MockResultSet((MockStatementBase) this.connection.createStatement()) + .mockResultSet(columnMetaColumnLabels, metas.toArray(new Object[0][])); } @Override diff --git a/rm-datasource/src/test/java/io/seata/rm/datasource/mock/MockExecuteHandlerImpl.java b/rm-datasource/src/test/java/io/seata/rm/datasource/mock/MockExecuteHandlerImpl.java index 759a0f42d29..1e5ea57595d 100644 --- a/rm-datasource/src/test/java/io/seata/rm/datasource/mock/MockExecuteHandlerImpl.java +++ b/rm-datasource/src/test/java/io/seata/rm/datasource/mock/MockExecuteHandlerImpl.java @@ -17,14 +17,23 @@ import java.sql.ResultSet; import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import com.alibaba.druid.sql.ast.statement.SQLExprTableSource; +import com.alibaba.druid.sql.ast.statement.SQLSelectQueryBlock; +import com.alibaba.druid.sql.SQLUtils; +import com.alibaba.druid.sql.ast.SQLStatement; +import com.alibaba.druid.sql.ast.statement.SQLSelectStatement; +import io.seata.sqlparser.druid.mysql.MySQLSelectForUpdateRecognizer; +import io.seata.sqlparser.util.JdbcConstants; import com.alibaba.druid.mock.MockStatementBase; import com.alibaba.druid.mock.handler.MockExecuteHandler; /** - * @author will - */ + * @author will + */ public class MockExecuteHandlerImpl implements MockExecuteHandler { /** @@ -55,11 +64,33 @@ public MockExecuteHandlerImpl(List mockReturnValueColumnLabels, Object[] @Override public ResultSet executeQuery(MockStatementBase statement, String sql) throws SQLException { MockResultSet resultSet = new MockResultSet(statement); - //mock the return value resultSet.mockResultSet(mockReturnValueColumnLabels, mockReturnValue); //mock the rs meta data - resultSet.mockResultSetMetaData(mockColumnsMetasReturnValue); + List asts = SQLUtils.parseStatements(sql, JdbcConstants.MYSQL); + List metas = new ArrayList<>(); + if(asts.get(0) instanceof SQLSelectStatement) { + SQLSelectStatement ast = (SQLSelectStatement) asts.get(0); + SQLSelectQueryBlock queryBlock = ast.getSelect().getQueryBlock(); + String tableName = ""; + if (queryBlock.getFrom() instanceof SQLExprTableSource) { + MySQLSelectForUpdateRecognizer recognizer = new MySQLSelectForUpdateRecognizer(sql, ast); + tableName = recognizer.getTableName(); + } else { + //select * from t inner join t1... + tableName = queryBlock.getFrom().toString(); + } + for (Object[] meta : mockColumnsMetasReturnValue) { + if (tableName.equalsIgnoreCase(meta[2].toString())) { + metas.add(meta); + } + } + } + if(metas.isEmpty()){ + //eg:select * from dual + metas = Arrays.asList(mockColumnsMetasReturnValue); + } + resultSet.mockResultSetMetaData(metas.toArray(new Object[0][])); return resultSet; } } diff --git a/sqlparser/seata-sqlparser-core/src/main/java/io/seata/sqlparser/SQLType.java b/sqlparser/seata-sqlparser-core/src/main/java/io/seata/sqlparser/SQLType.java index fd0c0955efe..520fc87748e 100644 --- a/sqlparser/seata-sqlparser-core/src/main/java/io/seata/sqlparser/SQLType.java +++ b/sqlparser/seata-sqlparser-core/src/main/java/io/seata/sqlparser/SQLType.java @@ -211,7 +211,11 @@ public enum SQLType { /** * Insert on duplicate update sql type. */ - INSERT_ON_DUPLICATE_UPDATE(102); + INSERT_ON_DUPLICATE_UPDATE(102), + /** + * update join sql type + */ + UPDATE_JOIN(103); private int i; diff --git a/sqlparser/seata-sqlparser-core/src/main/java/io/seata/sqlparser/SQLUpdateRecognizer.java b/sqlparser/seata-sqlparser-core/src/main/java/io/seata/sqlparser/SQLUpdateRecognizer.java index ad2dd6779c1..fa97f0416ff 100644 --- a/sqlparser/seata-sqlparser-core/src/main/java/io/seata/sqlparser/SQLUpdateRecognizer.java +++ b/sqlparser/seata-sqlparser-core/src/main/java/io/seata/sqlparser/SQLUpdateRecognizer.java @@ -23,6 +23,8 @@ * @author sharajava */ public interface SQLUpdateRecognizer extends WhereRecognizer { + String MULTI_TABLE_NAME_SEPERATOR = "#"; + /** * Gets update columns. @@ -38,6 +40,15 @@ public interface SQLUpdateRecognizer extends WhereRecognizer { */ List getUpdateValues(); + /** + * Gets update join item table name + * @param tableName the update join item table source name + * @return the update join item table alias name + */ + default String getTableAlias(String tableName) { + return null; + } + /** * Gets update columns is Simplified. * diff --git a/sqlparser/seata-sqlparser-druid/src/main/java/io/seata/sqlparser/druid/BaseRecognizer.java b/sqlparser/seata-sqlparser-druid/src/main/java/io/seata/sqlparser/druid/BaseRecognizer.java index c7f84f0ac97..ed4365ce156 100644 --- a/sqlparser/seata-sqlparser-druid/src/main/java/io/seata/sqlparser/druid/BaseRecognizer.java +++ b/sqlparser/seata-sqlparser-druid/src/main/java/io/seata/sqlparser/druid/BaseRecognizer.java @@ -26,7 +26,6 @@ import com.alibaba.druid.sql.ast.expr.SQLInSubQueryExpr; import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr; import com.alibaba.druid.sql.ast.statement.SQLInsertStatement; -import com.alibaba.druid.sql.ast.statement.SQLJoinTableSource; import com.alibaba.druid.sql.ast.statement.SQLMergeStatement; import com.alibaba.druid.sql.ast.statement.SQLReplaceStatement; import com.alibaba.druid.sql.ast.statement.SQLSubqueryTableSource; @@ -120,13 +119,6 @@ public String getOriginalSQL() { @Override public boolean isSqlSyntaxSupports() { SQLASTVisitor visitor = new SQLASTVisitorAdapter() { - @Override - public boolean visit(SQLJoinTableSource x) { - //just like: UPDATE table a INNER JOIN table b ON a.id = b.pid ... - throw new NotSupportYetException("not support the sql syntax with join table:" + x - + "\nplease see the doc about SQL restrictions https://seata.io/zh-cn/docs/user/sqlreference/dml.html"); - } - @Override public boolean visit(SQLInSubQueryExpr x) { //just like: ...where id in (select id from t) diff --git a/sqlparser/seata-sqlparser-druid/src/main/java/io/seata/sqlparser/druid/mysql/MySQLUpdateRecognizer.java b/sqlparser/seata-sqlparser-druid/src/main/java/io/seata/sqlparser/druid/mysql/MySQLUpdateRecognizer.java index 39e71094f2b..c3c747b98dc 100644 --- a/sqlparser/seata-sqlparser-druid/src/main/java/io/seata/sqlparser/druid/mysql/MySQLUpdateRecognizer.java +++ b/sqlparser/seata-sqlparser-druid/src/main/java/io/seata/sqlparser/druid/mysql/MySQLUpdateRecognizer.java @@ -16,7 +16,10 @@ package io.seata.sqlparser.druid.mysql; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; + import com.alibaba.druid.sql.ast.SQLExpr; import com.alibaba.druid.sql.ast.SQLLimit; import com.alibaba.druid.sql.ast.SQLOrderBy; @@ -31,11 +34,12 @@ import com.alibaba.druid.sql.ast.statement.SQLUpdateSetItem; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlUpdateStatement; import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlOutputVisitor; -import io.seata.common.exception.NotSupportYetException; import io.seata.sqlparser.util.ColumnUtils; import io.seata.sqlparser.ParametersHolder; import io.seata.sqlparser.SQLType; import io.seata.sqlparser.SQLUpdateRecognizer; +import io.seata.common.exception.NotSupportYetException; +import io.seata.common.exception.ShouldNeverHappenException; /** * The type My sql update recognizer. @@ -44,7 +48,9 @@ */ public class MySQLUpdateRecognizer extends BaseMySQLRecognizer implements SQLUpdateRecognizer { - private MySqlUpdateStatement ast; + private final MySqlUpdateStatement ast; + + private final Map tableName2AliasMap = new HashMap<>(4); /** * Instantiates a new My sql update recognizer. @@ -59,7 +65,14 @@ public MySQLUpdateRecognizer(String originalSQL, SQLStatement ast) { @Override public SQLType getSQLType() { - return SQLType.UPDATE; + SQLTableSource tableSource = this.ast.getTableSource(); + if (tableSource instanceof SQLExprTableSource) { + return SQLType.UPDATE; + } else if (tableSource instanceof SQLJoinTableSource) { + return SQLType.UPDATE_JOIN; + } else { + throw new NotSupportYetException("not support update table source with unknow"); + } } @Override @@ -129,30 +142,29 @@ public String getTableAlias() { @Override public String getTableName() { - StringBuilder sb = new StringBuilder(); - MySqlOutputVisitor visitor = new MySqlOutputVisitor(sb) { - - @Override - public boolean visit(SQLExprTableSource x) { - printTableSourceExpr(x.getExpr()); - return false; - } - - @Override - public boolean visit(SQLJoinTableSource x) { - throw new NotSupportYetException("not support the syntax of update with join table"); - } - }; - - SQLTableSource tableSource = ast.getTableSource(); + SQLTableSource tableSource = this.ast.getTableSource(); if (tableSource instanceof SQLExprTableSource) { - visitor.visit((SQLExprTableSource) tableSource); + return visitTableName((SQLExprTableSource) tableSource); } else if (tableSource instanceof SQLJoinTableSource) { - visitor.visit((SQLJoinTableSource) tableSource); + //update join sql,like update t1 inner join t2 on t1.id = t2.id set name = ?, age = ? + final int minTableNum = 2; + StringBuilder joinTables = new StringBuilder(); + joinTables.append(tableSource.toString()); + tableName2AliasMap.put(tableSource.toString(), tableSource.getAlias()); + this.getTableNames(tableSource, joinTables); + if (joinTables.toString().split(MULTI_TABLE_NAME_SEPERATOR).length < minTableNum + 1) { + throw new ShouldNeverHappenException("should get at least two table name for update join table source:" + tableSource.toString()); + } + //will return union table view name and single table names which linked by "#", like t1 inner join t2 on t1.id = t2.id#t1#t2 + return joinTables.toString(); } else { throw new NotSupportYetException("not support the syntax of update with unknow"); } - return sb.toString(); + } + + @Override + public String getTableAlias(String tableName) { + return tableName2AliasMap.get(tableName); } @Override @@ -183,4 +195,48 @@ public String getOrderByCondition(ParametersHolder parametersHolder, ArrayList updateSetItems = x.getItems(); + for (SQLUpdateSetItem updateSetItem : updateSetItems) { + if (updateSetItem.getValue() instanceof SQLQueryExpr) { + //just like: "update a set a.id = (select id from b where a.pid = b.pid)" + throw new NotSupportYetException("not support the sql syntax with join table:" + x + + "\nplease see the doc about SQL restrictions https://seata.io/zh-cn/docs/user/sqlreference/dml.html"); + } + } + return true; + } + @Override public boolean visit(SQLInSubQueryExpr x) { //just like: ...where id in (select id from t) @@ -135,7 +156,7 @@ public boolean visit(SQLInSubQueryExpr x) { @Override public boolean visit(OracleSelectSubqueryTableSource x) { - //just like: select * from (select * from t) + //just like: select * from (select * from t) for update throw new NotSupportYetException("not support the sql syntax with SubQuery:" + x + "\nplease see the doc about SQL restrictions https://seata.io/zh-cn/docs/user/sqlreference/dml.html"); } diff --git a/sqlparser/seata-sqlparser-druid/src/main/java/io/seata/sqlparser/druid/postgresql/BasePostgresqlRecognizer.java b/sqlparser/seata-sqlparser-druid/src/main/java/io/seata/sqlparser/druid/postgresql/BasePostgresqlRecognizer.java index 10cfd54d2e9..f44307a1544 100644 --- a/sqlparser/seata-sqlparser-druid/src/main/java/io/seata/sqlparser/druid/postgresql/BasePostgresqlRecognizer.java +++ b/sqlparser/seata-sqlparser-druid/src/main/java/io/seata/sqlparser/druid/postgresql/BasePostgresqlRecognizer.java @@ -18,8 +18,17 @@ import com.alibaba.druid.sql.ast.SQLExpr; import com.alibaba.druid.sql.ast.SQLLimit; import com.alibaba.druid.sql.ast.SQLOrderBy; +import com.alibaba.druid.sql.ast.expr.SQLInSubQueryExpr; import com.alibaba.druid.sql.ast.expr.SQLVariantRefExpr; +import com.alibaba.druid.sql.ast.statement.SQLMergeStatement; +import com.alibaba.druid.sql.ast.statement.SQLReplaceStatement; +import com.alibaba.druid.sql.ast.statement.SQLSubqueryTableSource; +import com.alibaba.druid.sql.dialect.postgresql.ast.stmt.PGInsertStatement; +import com.alibaba.druid.sql.dialect.postgresql.ast.stmt.PGUpdateStatement; +import com.alibaba.druid.sql.dialect.postgresql.visitor.PGASTVisitor; +import com.alibaba.druid.sql.dialect.postgresql.visitor.PGASTVisitorAdapter; import com.alibaba.druid.sql.dialect.postgresql.visitor.PGOutputVisitor; +import io.seata.common.exception.NotSupportYetException; import io.seata.common.util.StringUtils; import io.seata.sqlparser.ParametersHolder; import io.seata.sqlparser.druid.BaseRecognizer; @@ -67,6 +76,62 @@ public boolean visit(SQLVariantRefExpr x) { return visitor; } + @Override + public boolean isSqlSyntaxSupports() { + PGASTVisitor visitor = new PGASTVisitorAdapter() { + + @Override + public boolean visit(SQLSubqueryTableSource x) { + //just like: select * from (select * from t) for update + throw new NotSupportYetException("not support the sql syntax with SubQuery:" + x + + "\nplease see the doc about SQL restrictions https://seata.io/zh-cn/docs/user/sqlreference/dml.html"); + } + + @Override + public boolean visit(PGUpdateStatement x) { + if (x.getFrom() != null) { + //just like: update a set id = b.pid from b where a.id = b.id + throw new NotSupportYetException("not support the sql syntax with join table:" + x + + "\nplease see the doc about SQL restrictions https://seata.io/zh-cn/docs/user/sqlreference/dml.html"); + } + return true; + } + + @Override + public boolean visit(SQLInSubQueryExpr x) { + //just like: ...where id in (select id from t) + throw new NotSupportYetException("not support the sql syntax with InSubQuery:" + x + + "\nplease see the doc about SQL restrictions https://seata.io/zh-cn/docs/user/sqlreference/dml.html"); + } + + @Override + public boolean visit(SQLReplaceStatement x) { + //just like: replace into t (id,dr) values (1,'2'), (2,'3') + throw new NotSupportYetException("not support the sql syntax with ReplaceStatement:" + x + + "\nplease see the doc about SQL restrictions https://seata.io/zh-cn/docs/user/sqlreference/dml.html"); + } + + @Override + public boolean visit(SQLMergeStatement x) { + //just like: merge into ... WHEN MATCHED THEN ... + throw new NotSupportYetException("not support the sql syntax with MergeStatement:" + x + + "\nplease see the doc about SQL restrictions https://seata.io/zh-cn/docs/user/sqlreference/dml.html"); + } + + @Override + public boolean visit(PGInsertStatement x) { + if (null != x.getQuery()) { + //just like: insert into t select * from t1 + throw new NotSupportYetException("not support the sql syntax insert with query:" + x + + "\nplease see the doc about SQL restrictions https://seata.io/zh-cn/docs/user/sqlreference/dml.html"); + } + return true; + } + }; + getAst().accept(visitor); + return true; + } + public String getWhereCondition(SQLExpr where, final ParametersHolder parametersHolder, final ArrayList> paramAppenderList) { if (Objects.isNull(where)) { diff --git a/sqlparser/seata-sqlparser-druid/src/test/java/io/seata/sqlparser/druid/DruidSQLRecognizerFactoryTest.java b/sqlparser/seata-sqlparser-druid/src/test/java/io/seata/sqlparser/druid/DruidSQLRecognizerFactoryTest.java index 81243be33ba..b4397fb7323 100644 --- a/sqlparser/seata-sqlparser-druid/src/test/java/io/seata/sqlparser/druid/DruidSQLRecognizerFactoryTest.java +++ b/sqlparser/seata-sqlparser-druid/src/test/java/io/seata/sqlparser/druid/DruidSQLRecognizerFactoryTest.java @@ -15,8 +15,6 @@ */ package io.seata.sqlparser.druid; -import com.alibaba.druid.sql.SQLUtils; -import com.alibaba.druid.sql.ast.SQLStatement; import io.seata.common.exception.NotSupportYetException; import io.seata.common.loader.EnhancedServiceLoader; import io.seata.sqlparser.SQLRecognizer; @@ -45,49 +43,51 @@ public void testSqlRecognizerCreation() { Assertions.assertNotNull(recognizerFactory.create(sql, JdbcConstants.ORACLE)); Assertions.assertNotNull(recognizerFactory.create(sql, JdbcConstants.POSTGRESQL)); - String sql2 = "update table a inner join table b on a.id = b.pid set a.name = ?"; - Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql2, JdbcConstants.MYSQL)); + String sql1 = "update a set a.id = (select id from b where a.pid = b.pid)"; + Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql1, JdbcConstants.ORACLE)); + String sql2 = "update (select a.id,a.name from a inner join b on a.id = b.id) t set t.name = 'xxx'"; Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql2, JdbcConstants.ORACLE)); - Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql2, JdbcConstants.POSTGRESQL)); - - String sql3 = "update t set id = 1 where id in (select id from b)"; - Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql3, JdbcConstants.MYSQL)); - Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql3, JdbcConstants.ORACLE)); + String sql3 = "update a set id = b.pid from b where a.id = b.id"; Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql3, JdbcConstants.POSTGRESQL)); - String sql4 = "insert into a values (1, 2)"; - Assertions.assertNotNull(recognizerFactory.create(sql4, JdbcConstants.MYSQL)); - Assertions.assertNotNull(recognizerFactory.create(sql4, JdbcConstants.ORACLE)); - Assertions.assertNotNull(recognizerFactory.create(sql4, JdbcConstants.POSTGRESQL)); + String sql4 = "update t set id = 1 where id in (select id from b)"; + Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql4, JdbcConstants.MYSQL)); + Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql4, JdbcConstants.ORACLE)); + Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql4, JdbcConstants.POSTGRESQL)); - String sql5 = "insert into a (id, name) values (1, 2), (3, 4)"; + String sql5 = "insert into a values (1, 2)"; Assertions.assertNotNull(recognizerFactory.create(sql5, JdbcConstants.MYSQL)); Assertions.assertNotNull(recognizerFactory.create(sql5, JdbcConstants.ORACLE)); Assertions.assertNotNull(recognizerFactory.create(sql5, JdbcConstants.POSTGRESQL)); - String sql6 = "insert into a select * from b"; - Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql6, JdbcConstants.MYSQL)); - Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql6, JdbcConstants.ORACLE)); - Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql6, JdbcConstants.POSTGRESQL)); + String sql6 = "insert into a (id, name) values (1, 2), (3, 4)"; + Assertions.assertNotNull(recognizerFactory.create(sql6, JdbcConstants.MYSQL)); + Assertions.assertNotNull(recognizerFactory.create(sql6, JdbcConstants.ORACLE)); + Assertions.assertNotNull(recognizerFactory.create(sql6, JdbcConstants.POSTGRESQL)); + + String sql7 = "insert into a select * from b"; + Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql7, JdbcConstants.MYSQL)); + Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql7, JdbcConstants.ORACLE)); + Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql7, JdbcConstants.POSTGRESQL)); - String sql7 = "delete from t where id = ?"; - Assertions.assertNotNull(recognizerFactory.create(sql7, JdbcConstants.MYSQL)); - Assertions.assertNotNull(recognizerFactory.create(sql7, JdbcConstants.ORACLE)); - Assertions.assertNotNull(recognizerFactory.create(sql7, JdbcConstants.POSTGRESQL)); + String sql8 = "delete from t where id = ?"; + Assertions.assertNotNull(recognizerFactory.create(sql8, JdbcConstants.MYSQL)); + Assertions.assertNotNull(recognizerFactory.create(sql8, JdbcConstants.ORACLE)); + Assertions.assertNotNull(recognizerFactory.create(sql8, JdbcConstants.POSTGRESQL)); - String sql8 = "delete from t where id in (select id from b)"; - Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql8, JdbcConstants.MYSQL)); - Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql8, JdbcConstants.ORACLE)); - Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql8, JdbcConstants.POSTGRESQL)); + String sql9 = "delete from t where id in (select id from b)"; + Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql9, JdbcConstants.MYSQL)); + Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql9, JdbcConstants.ORACLE)); + Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql9, JdbcConstants.POSTGRESQL)); - String sql9 = "select * from t for update"; - Assertions.assertNotNull(recognizerFactory.create(sql9, JdbcConstants.MYSQL)); - Assertions.assertNotNull(recognizerFactory.create(sql9, JdbcConstants.ORACLE)); - Assertions.assertNotNull(recognizerFactory.create(sql9, JdbcConstants.POSTGRESQL)); + String sql10 = "select * from t for update"; + Assertions.assertNotNull(recognizerFactory.create(sql10, JdbcConstants.MYSQL)); + Assertions.assertNotNull(recognizerFactory.create(sql10, JdbcConstants.ORACLE)); + Assertions.assertNotNull(recognizerFactory.create(sql10, JdbcConstants.POSTGRESQL)); - String sql10 = "select * from (select * from t) for update"; - Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql10, JdbcConstants.MYSQL)); - Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql10, JdbcConstants.ORACLE)); - Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql10, JdbcConstants.POSTGRESQL)); + String sql11 = "select * from (select * from t) for update"; + Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql11, JdbcConstants.MYSQL)); + Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql11, JdbcConstants.ORACLE)); + Assertions.assertThrows(NotSupportYetException.class, () -> recognizerFactory.create(sql11, JdbcConstants.POSTGRESQL)); } } diff --git a/sqlparser/seata-sqlparser-druid/src/test/java/io/seata/sqlparser/druid/MySQLUpdateRecognizerTest.java b/sqlparser/seata-sqlparser-druid/src/test/java/io/seata/sqlparser/druid/MySQLUpdateRecognizerTest.java index e8c7c2039a9..5ba096a84a0 100644 --- a/sqlparser/seata-sqlparser-druid/src/test/java/io/seata/sqlparser/druid/MySQLUpdateRecognizerTest.java +++ b/sqlparser/seata-sqlparser-druid/src/test/java/io/seata/sqlparser/druid/MySQLUpdateRecognizerTest.java @@ -350,6 +350,15 @@ public void testGetTableAlias() { Assertions.assertNull(recognizer.getTableAlias()); } + @Test + public void testUpdateJoinSql() { + String sql = "update t1 inner join t2 on t1.id = t2.id set name = ?, age = ?"; + List asts = SQLUtils.parseStatements(sql, JdbcConstants.MYSQL); + MySQLUpdateRecognizer recognizer = new MySQLUpdateRecognizer(sql, asts.get(0)); + String tableName = recognizer.getTableName(); + Assertions.assertEquals("t1 INNER JOIN t2 ON t1.id = t2.id#t1#t2",tableName); + } + @Override public String getDbType() { return JdbcConstants.MYSQL; diff --git a/test/pom.xml b/test/pom.xml index 0361c706954..39031194048 100644 --- a/test/pom.xml +++ b/test/pom.xml @@ -125,6 +125,12 @@ seata-serializer-seata ${project.version} + + + ${project.groupId} + seata-sqlparser-druid + ${project.version} + diff --git a/test/src/test/java/io/seata/at/mysql/MysqlUpdateJoinTest.java b/test/src/test/java/io/seata/at/mysql/MysqlUpdateJoinTest.java new file mode 100644 index 00000000000..6bd9820b846 --- /dev/null +++ b/test/src/test/java/io/seata/at/mysql/MysqlUpdateJoinTest.java @@ -0,0 +1,166 @@ +/* + * Copyright 1999-2019 Seata.io Group. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.seata.at.mysql; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.Statement; + +import com.alibaba.druid.pool.DruidDataSource; +import com.alibaba.druid.util.JdbcUtils; +import io.seata.core.context.RootContext; +import io.seata.core.exception.TransactionException; +import io.seata.core.model.BranchStatus; +import io.seata.core.model.BranchType; +import io.seata.rm.DefaultResourceManager; +import io.seata.rm.datasource.DataCompareUtils; +import io.seata.rm.datasource.DataSourceManager; +import io.seata.rm.datasource.DataSourceProxy; +import io.seata.rm.datasource.sql.struct.TableMeta; +import io.seata.rm.datasource.sql.struct.TableMetaCacheFactory; +import io.seata.rm.datasource.sql.struct.TableRecords; +import io.seata.server.UUIDGenerator; +import io.seata.sqlparser.util.JdbcConstants; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + + +/** + * @author renliangyu857 + */ +public class MysqlUpdateJoinTest { + private static final int testRecordId = 1; + private static final int testRecordId1 = 2; + private static final long testTid = UUIDGenerator.generateUUID(); + private static final String mockXid = "127.0.0.1:8091:" + testTid; + private static final long mockBranchId = testTid + 1; + + private static final String mysql_jdbcUrl = "jdbc:mysql://127.0.0.1:3306/seata"; + private static final String mysql_username = "demo"; + private static final String mysql_password = "demo"; + private static final String mysql_driverClassName = JdbcUtils.MYSQL_DRIVER; + + + @Test + @Disabled + public void testUpdateJoin() throws Throwable { + doTestPhase2(false, "update t inner join t1 on t.a = t1.a set b = 3,t.c=3"); + System.out.println("AT MODE Phase2 test for update join looks good!"); + } + + private static void doPrepareData(String prepareSql) throws Throwable { + // init DataSource: helper + DruidDataSource helperDS = createNewDruidDataSource(); + + // prepare data for test: make sure no test record there + Connection helperConn = helperDS.getConnection(); + Statement helperStat = helperConn.createStatement(); + helperStat.execute(prepareSql); + helperStat.close(); + helperConn.close(); + } + + + private void doTestPhase2(boolean globalCommit, String updateSql) throws Throwable { + // init DataSource: helper + DruidDataSource helperDS = createNewDruidDataSource(); + + Connection helperConn = null; + Statement helperStat = null; + ResultSet table1HelperRes = null; + ResultSet table2HelperRes = null; + + initRM(); + + final DataSourceProxy dataSourceProxy = new DataSourceProxy(createNewDruidDataSource()); + + RootContext.bind(mockXid); + Connection testConn = dataSourceProxy.getConnection(); + Statement testStat = testConn.createStatement(); + + // >>> query before image + helperConn = helperDS.getConnection(); + helperStat = helperConn.createStatement(); + table1HelperRes = helperStat.executeQuery("select * from t where id = " + testRecordId ); + TableMeta table1Meta = TableMetaCacheFactory.getTableMetaCache(JdbcConstants.MYSQL).getTableMeta(dataSourceProxy.getPlainConnection(), + "t", dataSourceProxy.getResourceId()); + TableRecords table1BeforeImage = TableRecords.buildRecords(table1Meta, table1HelperRes); + table2HelperRes = helperStat.executeQuery("select * from t1 where id = " + testRecordId1); + TableMeta table2Meta = TableMetaCacheFactory.getTableMetaCache(JdbcConstants.MYSQL).getTableMeta(dataSourceProxy.getPlainConnection(), + "t1", dataSourceProxy.getResourceId()); + TableRecords table2BeforeImage = TableRecords.buildRecords(table2Meta, table2HelperRes); + // >>> update record should not throw exception + Assertions.assertDoesNotThrow(() -> testStat.execute(updateSql)); + // >>> close the statement and connection + testStat.close(); + testConn.close(); + RootContext.unbind(); + + if (globalCommit) { + // >>> Global Tx Phase 2: commit should not throw exception + Assertions.assertDoesNotThrow(() -> DefaultResourceManager.get().branchCommit(dataSourceProxy.getBranchType(), mockXid, mockBranchId, + dataSourceProxy.getResourceId(), null)); + } else { + DefaultResourceManager.get().branchRollback(dataSourceProxy.getBranchType(), mockXid, mockBranchId, dataSourceProxy.getResourceId(), null); + // >>> Global Tx Phase 2: rollback have a check,rollbacked record must equal to before image + helperConn = helperDS.getConnection(); + helperStat = helperConn.createStatement(); + table1HelperRes = helperStat.executeQuery("select * from t where id = " + testRecordId); + TableRecords table1CurrentImage = TableRecords.buildRecords(table1Meta, table1HelperRes); + table2HelperRes = helperStat.executeQuery("select * from t1 where id = " + testRecordId1); + TableRecords table2CurrentImage = TableRecords.buildRecords(table2Meta, table2HelperRes); + Assertions.assertTrue(DataCompareUtils.isRecordsEquals(table1BeforeImage, table1CurrentImage).getResult()); + Assertions.assertTrue(DataCompareUtils.isRecordsEquals(table2BeforeImage, table2CurrentImage).getResult()); + table1HelperRes.close(); + table2HelperRes.close(); + helperStat.close(); + helperConn.close(); + } + } + + private void initRM() { + // init RM + DefaultResourceManager.get(); + // mock the RM of AT + DefaultResourceManager.mockResourceManager(BranchType.AT, new DataSourceManager() { + @Override + public Long branchRegister(BranchType branchType, String resourceId, String clientId, String xid, String applicationData, String lockKeys) throws TransactionException { + return mockBranchId; + } + + @Override + public void branchReport(BranchType branchType, String xid, long branchId, BranchStatus status, String applicationData) throws TransactionException { + } + }); + + } + + private static DruidDataSource createNewDruidDataSource() throws Throwable { + DruidDataSource druidDataSource = new DruidDataSource(); + initDruidDataSource(druidDataSource); + return druidDataSource; + } + + private static void initDruidDataSource(DruidDataSource druidDataSource) throws Throwable { + druidDataSource.setDbType(JdbcConstants.MYSQL); + druidDataSource.setUrl(mysql_jdbcUrl); + druidDataSource.setUsername(mysql_username); + druidDataSource.setPassword(mysql_password); + druidDataSource.setDriverClassName(mysql_driverClassName); + druidDataSource.init(); + } +} \ No newline at end of file diff --git a/test/src/test/resources/README.md b/test/src/test/resources/README.md new file mode 100644 index 00000000000..c162008b3bb --- /dev/null +++ b/test/src/test/resources/README.md @@ -0,0 +1,21 @@ +## MySQLUpdateJoinTest +##测试表结构 +CREATE TABLE `t` ( +`id` int NOT NULL, +`a` int DEFAULT NULL, +`c` int DEFAULT NULL, +PRIMARY KEY (`id`), +KEY `a` (`a`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + +CREATE TABLE `t1` ( +`id` int NOT NULL, +`a` int DEFAULT NULL, +`b` int DEFAULT NULL, +`c` int DEFAULT NULL, +PRIMARY KEY (`id`), +KEY `a` (`a`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 +###测试数据 +insert into t(id,a,c) values(1,1,1);\ +insert into t1(id,a,b,c) values(2,1,2,2) \ No newline at end of file