├── .classpath ├── .gitignore ├── .project ├── .settings ├── org.eclipse.core.resources.prefs └── org.scala-ide.sdt.core.prefs ├── LICENSE ├── MANIFEST.MF.prototype ├── README.md ├── build.properties ├── build.sbt ├── patch-ide.bash ├── project ├── build.properties └── build.sbt └── src ├── main ├── java │ └── scala │ │ └── tools │ │ └── refactoring │ │ └── JavadocStub.java ├── scala-2.10 │ ├── README │ └── scala │ │ └── tools │ │ └── refactoring │ │ ├── ScalaVersionsAdapter.scala │ │ └── implementations │ │ └── oimports │ │ └── ImplicitValDefTraverserPF.scala ├── scala-2.11 │ ├── README │ └── scala │ │ └── tools │ │ └── refactoring │ │ ├── ScalaVersionsAdapter.scala │ │ └── implementations │ │ └── oimports │ │ └── ImplicitValDefTraverserPF.scala ├── scala-2.12 │ ├── README │ └── scala │ │ └── tools │ │ └── refactoring │ │ ├── ScalaVersionsAdapter.scala │ │ └── implementations │ │ └── oimports │ │ └── ImplicitValDefTraverserPF.scala └── scala │ └── scala │ └── tools │ └── refactoring │ ├── MultiStageRefactoring.scala │ ├── ParameterlessRefactoring.scala │ ├── Refactoring.scala │ ├── analysis │ ├── CompilationUnitDependencies.scala │ ├── CompilationUnitIndexes.scala │ ├── GlobalIndexes.scala │ ├── ImportAnalysis.scala │ ├── ImportsToolbox.scala │ ├── Indexes.scala │ ├── NameValidation.scala │ ├── PartiallyAppliedMethodsFinder.scala │ ├── ScopeAnalysis.scala │ ├── SymbolExpanders.scala │ └── TreeAnalysis.scala │ ├── common │ ├── Change.scala │ ├── CompilerAccess.scala │ ├── CompilerApiExtensions.scala │ ├── EnrichedTrees.scala │ ├── InsertionPositions.scala │ ├── InteractiveScalaCompiler.scala │ ├── Occurrences.scala │ ├── PositionDebugging.scala │ ├── Selections.scala │ ├── TracingHelpers.scala │ ├── TreeExtractors.scala │ ├── TreeTraverser.scala │ ├── exceptions.scala │ ├── package.scala │ └── tracing.scala │ ├── implementations │ ├── AddField.scala │ ├── AddImportStatement.scala │ ├── AddMethod.scala │ ├── AddValOrDef.scala │ ├── ChangeParamOrder.scala │ ├── 
ClassParameterDrivenSourceGeneration.scala │ ├── ExpandCaseClassBinding.scala │ ├── ExplicitGettersSetters.scala │ ├── ExtractLocal.scala │ ├── ExtractMethod.scala │ ├── ExtractTrait.scala │ ├── GenerateHashcodeAndEquals.scala │ ├── ImportsHelper.scala │ ├── InlineLocal.scala │ ├── IntroduceProductNTrait.scala │ ├── MarkOccurrences.scala │ ├── MergeParameterLists.scala │ ├── MethodSignatureRefactoring.scala │ ├── MoveClass.scala │ ├── MoveConstructorToCompanionObject.scala │ ├── OrganizeImports.scala │ ├── Rename.scala │ ├── SplitParameterLists.scala │ ├── UnusedImportsFinder.scala │ ├── extraction │ │ ├── ExtractCode.scala │ │ ├── ExtractExtractor.scala │ │ ├── ExtractMethod.scala │ │ ├── ExtractParameter.scala │ │ ├── ExtractValue.scala │ │ └── ExtractionRefactoring.scala │ └── oimports │ │ ├── ImportParticipants.scala │ │ ├── ImportsOrganizer.scala │ │ ├── OrganizeImportsWorker.scala │ │ ├── Region.scala │ │ ├── RegionTransformations.scala │ │ └── TreeToolbox.scala │ ├── package.scala │ ├── sourcegen │ ├── AbstractPrinter.scala │ ├── CommentsUtils.scala │ ├── CommonPrintUtils.scala │ ├── Formatting.scala │ ├── Fragment.scala │ ├── Indentation.scala │ ├── Layout.scala │ ├── LayoutHelper.scala │ ├── PrettyPrinter.scala │ ├── Requisite.scala │ ├── ReusingPrinter.scala │ ├── SourceGenerator.scala │ ├── SourceUtils.scala │ ├── TreeChangesDiscoverer.scala │ └── TreePrintingTraversals.scala │ ├── transformation │ ├── TransformableSelections.scala │ ├── Transformations.scala │ ├── TreeFactory.scala │ └── TreeTransformations.scala │ └── util │ ├── CompilerProvider.scala │ ├── Memoized.scala │ ├── SourceHelpers.scala │ ├── SourceWithMarker.scala │ ├── SourceWithSelection.scala │ ├── UnionFind.scala │ └── UniqueNames.scala └── test ├── java └── scala │ └── tools │ └── refactoring │ ├── common │ └── TracingHelpersTest.scala │ └── tests │ └── util │ ├── ExceptionWrapper.java │ ├── ScalaVersion.java │ ├── ScalaVersionTestRule.java │ └── TestRules.java ├── scala-2.10 ├── 
README └── scala │ └── tools │ └── refactoring │ └── tests │ └── implementations │ └── imports │ ├── OrganizeImportsScalaSpecificTests.scala │ └── OrganizeImportsWithMacrosTest.scala ├── scala-2.11 ├── README └── scala │ └── tools │ └── refactoring │ └── tests │ └── implementations │ └── imports │ ├── OrganizeImportsScalaSpecificTests.scala │ └── OrganizeImportsWithMacrosTest.scala ├── scala-2.12 ├── README └── scala │ └── tools │ └── refactoring │ └── tests │ └── implementations │ └── imports │ ├── OrganizeImportsScalaSpecificTests.scala │ └── OrganizeImportsWithMacrosTest.scala └── scala └── scala └── tools └── refactoring ├── implementations └── OrganizeImportsAlgosTest.scala └── tests ├── RefactoringTestSuite.scala ├── analysis ├── CompilationUnitDependenciesTest.scala ├── DeclarationIndexTest.scala ├── FindShadowedTest.scala ├── ImportAnalysisTest.scala ├── MultipleFilesIndexTest.scala ├── NameValidationTest.scala ├── ScopeAnalysisTest.scala └── TreeAnalysisTest.scala ├── common ├── EnrichedTreesTest.scala ├── InsertionPositionsTest.scala ├── OccurrencesTest.scala ├── SelectionDependenciesTest.scala ├── SelectionExpansionsTest.scala ├── SelectionPropertiesTest.scala └── SelectionsTest.scala ├── implementations ├── AddFieldTest.scala ├── AddMethodTest.scala ├── ChangeParamOrderTest.scala ├── ExpandCaseClassBindingTest.scala ├── ExplicitGettersSettersTest.scala ├── ExtractLocalTest.scala ├── ExtractMethodTest.scala ├── ExtractTraitTest.scala ├── GenerateHashcodeAndEqualsTest.scala ├── InlineLocalTest.scala ├── IntroduceProductNTraitTest.scala ├── MarkOccurrencesTest.scala ├── MergeParameterListsTest.scala ├── MoveClassTest.scala ├── MoveConstructorToCompanionObjectTest.scala ├── RenameTest.scala ├── SplitParameterListsTest.scala ├── extraction │ ├── ExtractCodeTest.scala │ ├── ExtractExtractorTest.scala │ ├── ExtractMethodTest.scala │ ├── ExtractParameterTest.scala │ ├── ExtractValueTest.scala │ └── ExtractionsTest.scala └── imports │ ├── 
AddImportStatementTest.scala │ ├── OrganizeImportsBaseTest.scala │ ├── OrganizeImportsCollapseSelectorsToWildcardTest.scala │ ├── OrganizeImportsEndOfLineTest.scala │ ├── OrganizeImportsFullyRecomputeTest.scala │ ├── OrganizeImportsGroupsTest.scala │ ├── OrganizeImportsRecomputeAndModifyTest.scala │ ├── OrganizeImportsTest.scala │ ├── OrganizeImportsWildcardsTest.scala │ ├── OrganizeMissingImportsTest.scala │ ├── PrependOrDropScalaPackageFromRecomputedTest.scala │ ├── PrependOrDropScalaPackageKeepTest.scala │ └── UnusedImportsFinderTest.scala ├── sourcegen ├── CustomFormattingTest.scala ├── IndividualSourceGenTest.scala ├── LayoutTest.scala ├── PrettyPrinterTest.scala ├── ReusingPrinterTest.scala ├── SourceGenTest.scala ├── SourceHelperTest.scala ├── SourceUtilsTest.scala └── TreeChangesDiscovererTest.scala ├── transformation ├── TransformableSelectionTest.scala └── TreeTransformationsTest.scala └── util ├── FreshCompilerForeachTest.scala ├── SourceHelpersTest.scala ├── SourceWithMarkerTest.scala ├── TestHelper.scala ├── TestRefactoring.scala ├── TextSelections.scala ├── TextSelectionsTest.scala ├── UnionFindInitTest.scala └── UnionFindTest.scala /.classpath: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target 2 | bin 3 | .cache-main 4 | .cache-tests 5 | .attach_pid* 6 | **/*.tmpBin 7 | *.swp 8 | .ensime* 9 | .idea 10 | *.iml 11 | -------------------------------------------------------------------------------- /.project: -------------------------------------------------------------------------------- 1 | 2 | 3 | org.scala-refactoring.library 4 | 5 | 6 | 7 | 8 | 9 | org.eclipse.pde.ManifestBuilder 10 | 11 | 12 | 13 | 14 | org.eclipse.pde.SchemaBuilder 15 | 16 | 17 | 18 | 19 | 
org.scala-ide.sdt.core.scalabuilder 20 | 21 | 22 | 23 | 24 | org.scalastyle.scalastyleplugin.core.ScalastyleBuilder 25 | 26 | 27 | 28 | 29 | 30 | org.scala-ide.sdt.core.scalanature 31 | org.eclipse.jdt.core.javanature 32 | org.eclipse.pde.PluginNature 33 | org.scalastyle.scalastyleplugin.core.ScalastyleNature 34 | 35 | 36 | -------------------------------------------------------------------------------- /.settings/org.eclipse.core.resources.prefs: -------------------------------------------------------------------------------- 1 | #Generated by sbteclipse 2 | #Thu Oct 13 18:48:15 CEST 2016 3 | encoding/=UTF-8 4 | -------------------------------------------------------------------------------- /.settings/org.scala-ide.sdt.core.prefs: -------------------------------------------------------------------------------- 1 | //src/main/scala=main 2 | //src/main/scala-2.12=main 3 | //src/test/java=tests 4 | //src/test/scala=tests 5 | //src/test/scala-2.12=tests 6 | P= 7 | Xcheckinit=false 8 | Xdisable-assertions=false 9 | Xelide-below=-2147483648 10 | Xexperimental=false 11 | Xfatal-warnings=true 12 | Xfuture=true 13 | Xlog-implicits=false 14 | Xno-uescape=false 15 | Xplugin= 16 | Xplugin-disable= 17 | Xplugin-require= 18 | Xpluginsdir=misc/scala-devel/plugins 19 | Ypresentation-debug=false 20 | Ypresentation-delay=0 21 | Ypresentation-log= 22 | Ypresentation-replay= 23 | Ypresentation-verbose=false 24 | Ywarn-dead-code=true 25 | apiDiff=false 26 | compileorder=Mixed 27 | deprecation=true 28 | eclipse.preferences.version=1 29 | explaintypes=false 30 | feature=true 31 | g=vars 32 | no-specialization=false 33 | nowarn=false 34 | optimise=false 35 | recompileOnMacroDef=true 36 | relationsDebug=false 37 | scala.compiler.additionalParams=-deprecation\:false -encoding UTF-8 -feature -language\:_ -Xlint\:-unused,_ -Yno-adapted-args -Ywarn-unused-import -Ywarn-unused\:imports,privates,locals,-patvars,-params,-implicits,_ 38 | scala.compiler.installation=2.12 39 | 
scala.compiler.sourceLevel=2.12 40 | scala.compiler.useProjectSettings=true 41 | stopBuildOnError=true 42 | target=jvm-1.8 43 | unchecked=true 44 | useScopesCompiler=true 45 | verbose=false 46 | withVersionClasspathValidator=false 47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | SCALA LICENSE 2 | 3 | Copyright (c) 2002-2010 EPFL, Lausanne, unless otherwise specified. 4 | All rights reserved. 5 | 6 | This software was developed by the Programming Methods Laboratory of the 7 | Swiss Federal Institute of Technology (EPFL), Lausanne, Switzerland. 8 | 9 | Permission to use, copy, modify, and distribute this software in source 10 | or binary form for any purpose with or without fee is hereby granted, 11 | provided that the following conditions are met: 12 | 13 | 1. Redistributions of source code must retain the above copyright 14 | notice, this list of conditions and the following disclaimer. 15 | 16 | 2. Redistributions in binary form must reproduce the above copyright 17 | notice, this list of conditions and the following disclaimer in the 18 | documentation and/or other materials provided with the distribution. 19 | 20 | 3. Neither the name of the EPFL nor the names of its contributors 21 | may be used to endorse or promote products derived from this 22 | software without specific prior written permission. 23 | 24 | 25 | THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND 26 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 27 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 28 | ARE DISCLAIMED. 
IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE 29 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 30 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 31 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 32 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 33 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 34 | OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF 35 | SUCH DAMAGE. 36 | -------------------------------------------------------------------------------- /MANIFEST.MF.prototype: -------------------------------------------------------------------------------- 1 | Manifest-Version: 1.0 2 | Bundle-ManifestVersion: 2 3 | Bundle-Name: Scala Refactoring 4 | Bundle-SymbolicName: org.scala-refactoring.library 5 | Bundle-Version: version.qualifier 6 | Require-Bundle: org.junit;bundle-version="4.11.0", 7 | org.scala-lang.scala-library 8 | Export-Package: scala.tools.refactoring,scala.tools.refactoring.analys 9 | is,scala.tools.refactoring.common,scala.tools.refactoring.implementat 10 | ions,scala.tools.refactoring.implementations.extraction,scala.tools.r 11 | efactoring.sourcegen,scala.tools.refactoring.transformation,scala.too 12 | ls.refactoring.util 13 | Bundle-ClassPath: . 
14 | Import-Package: scala.reflect.internal;resolution:=optional, 15 | scala.reflect.internal.util;resolution:=optional, 16 | scala.reflect.api;resolution:=optional, 17 | scala.reflect.runtime;resolution:=optional, 18 | scala.reflect.macros;resolution:=optional, 19 | scala.reflect.io;resolution:=optional, 20 | scala.tools.nsc, 21 | scala.tools.nsc.ast, 22 | scala.tools.nsc.ast.parser, 23 | scala.tools.nsc.interactive, 24 | scala.tools.nsc.io, 25 | scala.tools.nsc.reporters, 26 | scala.tools.nsc.settings, 27 | scala.tools.nsc.symtab, 28 | scala.tools.nsc.typechecker, 29 | scala.tools.nsc.util 30 | Bundle-RequiredExecutionEnvironment: JavaSE-1.6 31 | -------------------------------------------------------------------------------- /build.properties: -------------------------------------------------------------------------------- 1 | bin.includes = META-INF/,\ 2 | .,\ 3 | src/main/,\ 4 | src/test/ 5 | src.includes = src/ 6 | jars.compile.order = . 7 | source.. = src/main/scala/,\ 8 | src/test/scala/,\ 9 | src/test/java/ 10 | output.. 
= bin/ 11 | 12 | -------------------------------------------------------------------------------- /patch-ide.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | SCRIPT_NAME="./$(basename "$0")" 5 | 6 | showHelp() { 7 | echo "Patches ScalaIDE with the latest local build" 8 | echo "--------------------------------------------" 9 | echo "" 10 | echo "The script is controlled by the following environment variables:" 11 | echo " SCALA_IDE_HOME (mandatory):" 12 | echo " Path to you local ScalaIDE installation" 13 | echo "" 14 | echo "Examples: " 15 | echo " SCALA_IDE_HOME=\"/path/to/scala-ide\" $SCRIPT_NAME" 16 | echo " SCALA_IDE_HOME=\"/path/to/scala-ide\" KEEP_REFACTORING_LIBRARY_BACKUP=true $SCRIPT_NAME" 17 | echo "" 18 | echo "Best practice:" 19 | echo " If you use the script regularly, it is recommended to export" 20 | echo " an appropriate value for SCALA_IDE_HOME via your bashrc, so that you" 21 | echo " don't have to specify this setting repeatedly." 22 | echo "" 23 | echo "Warning:" 24 | echo " Note that patching the IDE like this only works as long as" 25 | echo " binary compatibility is maintained. Watch out for" 26 | echo " AbstractMethodErrors and the like." 27 | } 28 | 29 | showHelpAndDie() { 30 | showHelp 31 | exit 1 32 | } 33 | 34 | echoErr() { 35 | cat <<< "$@" 1>&2 36 | } 37 | 38 | KEEP_REFACTORING_LIBRARY_BACKUP=${KEEP_REFACTORING_LIBRARY_BACKUP:-true} 39 | 40 | if [[ -z "$SCALA_IDE_HOME" ]]; then 41 | showHelpAndDie 42 | fi 43 | 44 | SCALA_IDE_PLUGINS_DIR="$SCALA_IDE_HOME/plugins" 45 | 46 | if [[ ! -d "$SCALA_IDE_PLUGINS_DIR" || ! -w "$SCALA_IDE_PLUGINS_DIR" ]]; then 47 | echoErr "Invalid SCALA_IDE_HOME: $SCALA_IDE_PLUGINS_DIR is not a writable directory" 48 | exit 1 49 | fi 50 | 51 | TARGET_FOLDER="./target/scala-2.11/" 52 | 53 | shopt -s nullglob 54 | _newRefactoringJars=("$TARGET_FOLDER"*SNAPSHOT.jar) 55 | NEW_REFACTORING_JAR="${_newRefactoringJars[0]}" 56 | 57 | if [[ ! 
-f "$NEW_REFACTORING_JAR" || ! -r "$NEW_REFACTORING_JAR" ]]; then 58 | echoErr "Cannot find a build of the library in $TARGET_FOLDER" 59 | exit 1 60 | fi 61 | 62 | TSTAMP="$(date +%Y-%m-%dT%H-%M-%S)" 63 | 64 | _oldRefactoringJar=($SCALA_IDE_PLUGINS_DIR/org.scala-refactoring.library*.jar) 65 | if [[ ${#_oldRefactoringJar[@]} == 0 ]]; then 66 | echoErr "Cannot find the refactoring library in $SCALA_IDE_PLUGINS_DIR" 67 | exit 1 68 | elif [[ ${#_oldRefactoringJar[@]} -gt 1 ]]; then 69 | echoErr "Multiple copies of the refactoring library found in $SCALA_IDE_PLUGINS_DIR:" 70 | for jarFile in "${_oldRefactoringJar[@]}"; do 71 | echoErr " $jarFile" 72 | done 73 | exit 1 74 | fi 75 | 76 | OLD_REFACTORING_JAR="${_oldRefactoringJar[0]}" 77 | BACKUP_REFACTORING_JAR="$OLD_REFACTORING_JAR.$TSTAMP.bak" 78 | cp "$OLD_REFACTORING_JAR" "$BACKUP_REFACTORING_JAR" 79 | cp -i "$NEW_REFACTORING_JAR" "$OLD_REFACTORING_JAR" 80 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=0.13.16 2 | -------------------------------------------------------------------------------- /project/build.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0") 2 | 3 | // We need to be able to not add the scoverage plugin to the scala-refactoring build. 4 | // This is necessary because scala-refactoring is built during Scala PR CI, which means 5 | // that we can not rely on any plugins that depend on scalac. For normal scala-refactoring 6 | // builds this variable should not be set, the Scala PR CI however nedes to set it. 
7 | if (sys.env.contains("OMIT_SCOVERAGE_PLUGIN")) 8 | Nil 9 | else 10 | List(addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.5.0")) 11 | -------------------------------------------------------------------------------- /src/main/java/scala/tools/refactoring/JavadocStub.java: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring; 2 | 3 | /* 4 | * An empty class to create a fake javadoc.jar as required by Sonatype. 5 | * 6 | * */ 7 | public class JavadocStub { 8 | } 9 | -------------------------------------------------------------------------------- /src/main/scala-2.10/README: -------------------------------------------------------------------------------- 1 | Put sources specifically for Scala-2.10.x here -------------------------------------------------------------------------------- /src/main/scala-2.10/scala/tools/refactoring/ScalaVersionsAdapter.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | 3 | object ScalaVersionAdapters { 4 | trait CompilerApiAdapters { 5 | val global: scala.tools.nsc.Global 6 | import global._ 7 | 8 | def annotationInfoTree(info: AnnotationInfo): Tree = info.original 9 | 10 | def isImplementationArtifact(sym: Symbol): Boolean = { 11 | sym.isImplementationArtifact || { 12 | // Unfortunately, for Scala-2.10, we cannot rely on `isImplementationArtifact` 13 | // as this method might return wrong results.
To mitigate this, we fall back to 14 | // the hack below: 15 | sym.name.toString.contains("$") 16 | } 17 | } 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/main/scala-2.10/scala/tools/refactoring/implementations/oimports/ImplicitValDefTraverserPF.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | package implementations.oimports 3 | 4 | import scala.tools.nsc.Global 5 | 6 | class ImplicitValDefTraverserPF[G <: Global](val global: G) { 7 | import global._ 8 | 9 | /** Unsupported for Scala 2.10. */ 10 | def apply(traverser: Traverser): PartialFunction[Tree, Unit] = PartialFunction.empty 11 | } 12 | -------------------------------------------------------------------------------- /src/main/scala-2.11/README: -------------------------------------------------------------------------------- 1 | Put sources specifically for Scala-2.11.x here -------------------------------------------------------------------------------- /src/main/scala-2.11/scala/tools/refactoring/ScalaVersionsAdapter.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | 3 | object ScalaVersionAdapters { 4 | trait CompilerApiAdapters { 5 | val global: scala.tools.nsc.Global 6 | import global._ 7 | 8 | def annotationInfoTree(info: AnnotationInfo): Tree = info.tree 9 | 10 | def isImplementationArtifact(sym: Symbol): Boolean = { 11 | sym.isImplementationArtifact 12 | } 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /src/main/scala-2.11/scala/tools/refactoring/implementations/oimports/ImplicitValDefTraverserPF.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | package implementations.oimports 3 | 4 | import scala.tools.nsc.Global 5 | 6 | class ImplicitValDefTraverserPF[G <:
Global](val global: G) { 7 | import global._ 8 | import global.analyzer._ 9 | 10 | private def continueWithFunction(traverser: Traverser, rhs: Attachable): Unit = rhs match { 11 | case TypeApply(fun, _) => traverser.traverse(fun) 12 | case _ => () // other rhs shapes: nothing further to traverse here (previously threw MatchError) 13 | } 14 | 15 | def apply(traverser: Traverser): PartialFunction[Tree, Unit] = { 16 | case ValDef(_, _, _, rhs: Attachable) if rhs.hasAttachment[MacroExpansionAttachment] => 17 | val mea = rhs.attachments.get[MacroExpansionAttachment] 18 | mea.collect { 19 | case MacroExpansionAttachment(_, expanded: Typed) => 20 | expanded 21 | }.foreach { expanded => 22 | traverser.traverse(expanded.expr) 23 | } 24 | continueWithFunction(traverser, rhs) 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/scala-2.12/README: -------------------------------------------------------------------------------- 1 | Put sources specifically for Scala-2.12.x here -------------------------------------------------------------------------------- /src/main/scala-2.12/scala/tools/refactoring/ScalaVersionsAdapter.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | 3 | object ScalaVersionAdapters { 4 | trait CompilerApiAdapters { 5 | val global: scala.tools.nsc.Global 6 | import global._ 7 | 8 | def annotationInfoTree(info: AnnotationInfo): Tree = info.tree 9 | 10 | def isImplementationArtifact(sym: Symbol): Boolean = { 11 | sym.isImplementationArtifact 12 | } 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /src/main/scala-2.12/scala/tools/refactoring/implementations/oimports/ImplicitValDefTraverserPF.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | package implementations.oimports 3 | 4 | import scala.tools.nsc.Global 5 | 6 | class ImplicitValDefTraverserPF[G <: Global](val global: G) { 7 | import global._ 8 |
import global.analyzer._ 9 | 10 | private def continueWithFunction(traverser: Traverser, rhs: Attachable): Unit = rhs match { 11 | case TypeApply(fun, _) => traverser.traverse(fun) 12 | case _ => () // other rhs shapes: nothing further to traverse here (previously threw MatchError) 13 | } 14 | 15 | def apply(traverser: Traverser): PartialFunction[Tree, Unit] = { 16 | case ValDef(_, _, _, rhs: Attachable) if rhs.hasAttachment[MacroExpansionAttachment] => 17 | val mea = rhs.attachments.get[MacroExpansionAttachment] 18 | mea.collect { 19 | case MacroExpansionAttachment(_, expanded: Typed) => 20 | expanded 21 | }.foreach { expanded => 22 | traverser.traverse(expanded.expr) 23 | } 24 | continueWithFunction(traverser, rhs) 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/MultiStageRefactoring.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | 7 | import common.Change 8 | 9 | /** 10 | * The super class of all refactoring implementations, 11 | * representing the several phases of the refactoring 12 | * process. 13 | */ 14 | abstract class MultiStageRefactoring extends Refactoring { 15 | 16 | this: common.CompilerAccess => 17 | 18 | /** 19 | * Preparing a refactoring can either return a result 20 | * or an instance of PreparationError, describing the 21 | * cause why the refactoring cannot be performed. 22 | */ 23 | 24 | type PreparationResult 25 | 26 | case class PreparationError(cause: String) 27 | 28 | def prepare(s: Selection): Either[PreparationError, PreparationResult] 29 | 30 | /** 31 | * Refactorings are parameterized by the user, and to keep 32 | * them stateless, the result of the preparation step needs 33 | * to be passed to the perform method. 34 | * 35 | * The result can either be an error or a list of trees that 36 | * contain changes.
37 | */ 38 | 39 | type RefactoringParameters 40 | 41 | case class RefactoringError(cause: String) 42 | 43 | def perform(selection: Selection, prepared: PreparationResult, params: RefactoringParameters): Either[RefactoringError, List[Change]] 44 | 45 | } 46 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/ParameterlessRefactoring.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | 3 | import common.Change 4 | 5 | /** 6 | * A helper trait for refactorings that don't take RefactoringParameters. 7 | * 8 | * With this trait, the refactoring can implement the simplified perform 9 | * method. 10 | */ 11 | trait ParameterlessRefactoring { 12 | 13 | this: MultiStageRefactoring => 14 | 15 | class RefactoringParameters 16 | 17 | def perform(selection: Selection, prepared: PreparationResult, params: RefactoringParameters): Either[RefactoringError, List[Change]] = { 18 | perform(selection, prepared) 19 | } 20 | 21 | def perform(selection: Selection, prepared: PreparationResult): Either[RefactoringError, List[Change]] 22 | } -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/Refactoring.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | 7 | import scala.tools.nsc.io.AbstractFile 8 | import scala.tools.refactoring.common.Selections 9 | import scala.tools.refactoring.common.EnrichedTrees 10 | import scala.tools.refactoring.sourcegen.SourceGenerator 11 | import scala.tools.refactoring.transformation.TreeTransformations 12 | import scala.tools.refactoring.common.TextChange 13 | import scala.tools.refactoring.common.TracingImpl 14 | 15 | /** 16 | * The Refactoring trait combines the transformation and source generation traits with 17 |
their dependencies. Refactoring is mixed in by all concrete refactorings and can be 18 | used by users of the library. 19 | */ 20 | trait Refactoring extends Selections with TreeTransformations with TracingImpl with SourceGenerator with EnrichedTrees { 21 | 22 | this: common.CompilerAccess => 23 | 24 | /** 25 | * Creates a list of changes from a list of (potentially changed) trees. 26 | * 27 | * @param changed A list of trees that are to be searched for modifications. 28 | * @return A list of changes that can be applied to the source file. 29 | */ 30 | def refactor(changed: List[global.Tree]): List[TextChange] = context("main") { 31 | val changes = createChanges(changed) 32 | changes map minimizeChange 33 | } 34 | 35 | /** 36 | * Creates changes by applying a transformation to the root tree of an 37 | * abstract file. 38 | */ 39 | def transformFile(file: AbstractFile, transformation: Transformation[global.Tree, global.Tree]): List[TextChange] = { 40 | refactor(transformation(abstractFileToTree(file)).toList) 41 | } 42 | 43 | /** 44 | * Creates changes by applying several transformations to the root tree 45 | * of an abstract file. 46 | * Each transformation creates a new root tree that is used as input of 47 | * the next transformation. 48 | */ 49 | def transformFile(file: AbstractFile, transformations: List[Transformation[global.Tree, global.Tree]]): List[TextChange] = { 50 | def inner(root: global.Tree, ts: List[Transformation[global.Tree, global.Tree]]): Option[global.Tree] = { 51 | ts match { 52 | case t :: rest => 53 | t(root) match { 54 | case Some(newRoot) => inner(newRoot, rest) 55 | case None => None 56 | } 57 | case Nil => Some(root) 58 | } 59 | } 60 | 61 | refactor(inner(abstractFileToTree(file), transformations).toList) 62 | } 63 | 64 | /** 65 | * Makes a generated change as small as possible by eliminating the 66 | * common pre- and suffix between the change and the source file.
67 | */ 68 | private def minimizeChange(change: TextChange): TextChange = change match { 69 | case TextChange(file, from, to, changeText) => 70 | 71 | def commonPrefixLength(s1: Seq[Char], s2: Seq[Char]) = 72 | (s1 zip s2 takeWhile Function.tupled(_ == _)).length 73 | 74 | val original = file.content.subSequence(from, to).toString 75 | val replacement = changeText 76 | 77 | val commonStart = commonPrefixLength(original, replacement) 78 | val commonEnd = commonPrefixLength(original.substring(commonStart).reverse, replacement.substring(commonStart).reverse) 79 | 80 | val minimizedChangeText = changeText.subSequence(commonStart, changeText.length - commonEnd).toString 81 | TextChange(file, from + commonStart, to - commonEnd, minimizedChangeText) 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/analysis/CompilationUnitIndexes.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package analysis 7 | 8 | import collection.mutable.HashMap 9 | import collection.mutable.ListBuffer 10 | 11 | /** 12 | * A CompilationUnitIndex is a light-weight index that 13 | * holds all definitions and references in a compilation 14 | * unit. This index is built with the companion object, 15 | * which traverses the whole compilation unit once and 16 | * then memoizes all relations.
17 | * 18 | */ 19 | trait CompilationUnitIndexes { 20 | 21 | this: common.EnrichedTrees with common.CompilerAccess with common.TreeTraverser => 22 | 23 | import global._ 24 | 25 | trait CompilationUnitIndex { 26 | def root: Tree 27 | def definitions: Map[Symbol, List[DefTree]] 28 | def references: Map[Symbol, List[Tree]] 29 | } 30 | object CompilationUnitIndex { 31 | 32 | 33 | private lazy val scalaVersion = { 34 | val Version = "version (\\d+)\\.(\\d+)\\.(\\d+).*".r 35 | scala.util.Properties.versionString match { 36 | case Version(fst, snd, trd) => (fst.toInt, snd.toInt, trd.toInt) 37 | } 38 | } 39 | 40 | def apply(tree: Tree): CompilationUnitIndex = { 41 | 42 | assertCurrentThreadIsPresentationCompiler() 43 | 44 | val defs = new HashMap[Symbol, ListBuffer[DefTree]] 45 | val refs = new HashMap[Symbol, ListBuffer[Tree]] 46 | 47 | def addDefinition(s: Symbol, t: DefTree): Unit = { 48 | def add(s: Symbol) = 49 | defs.getOrElseUpdate(s, new ListBuffer[DefTree]) += t 50 | 51 | def isLowerScalaVersionThan2_10_1 = { 52 | Ordering[(Int, Int, Int)].lt(scalaVersion, (2, 10, 1)) // lexicographic (major, minor, patch) comparison; the old check ignored the major version 53 | } 54 | 55 | t.symbol match { 56 | case ts: TermSymbol if ts.isLazy && isLowerScalaVersionThan2_10_1 => 57 | add(ts.lazyAccessor) 58 | case _ => 59 | add(s) 60 | } 61 | } 62 | 63 | def addReference(s: Symbol, t: Tree): Unit = { 64 | def add(s: Symbol) = 65 | refs.getOrElseUpdate(s, new ListBuffer[Tree]) += t 66 | 67 | add(s) 68 | 69 | s match { 70 | case _: ClassSymbol => () 71 | /* 72 | * If we only have a TypeSymbol, we check if it is 73 | * a reference to another symbol and add this to the 74 | * index as well.
75 | * 76 | * This is needed for example to find the TypeTree 77 | * of a DefDef parameter-ValDef 78 | * */ 79 | case ts: TypeSymbol => 80 | ts.info match { 81 | case tr: TypeRef if tr.sym != null && /*otherwise we get wrong matches because of Type-Aliases*/ 82 | tr.sym.nameString == s.nameString => 83 | add(tr.sym) 84 | case _ => () 85 | } 86 | case _ => () 87 | } 88 | } 89 | 90 | def handleSymbol(s: Symbol, t: Tree) = t match { 91 | case t: DefTree => addDefinition(s, t) 92 | case _ => addReference(s, t) 93 | } 94 | 95 | (new TreeWithSymbolTraverser(handleSymbol)).traverse(tree) 96 | 97 | new CompilationUnitIndex { 98 | val root = tree 99 | val definitions = defs.map{ case (sym, v) => sym.initialize → v.toList}.toMap 100 | val references = refs.map{ case (sym, v) => sym.initialize → v.toList}.toMap 101 | } 102 | } 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/analysis/ImportsToolbox.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | package analysis 3 | 4 | /** Class to wrap path dependent type on `CompilationUnitDependencies` used in `CompilationUnitDependencies`. */ 5 | class ImportsToolbox[C <: CompilationUnitDependencies with common.EnrichedTrees](val cuDependenciesInstance: C) { 6 | import cuDependenciesInstance.global._ 7 | 8 | def apply(tree: Tree) = new IsSelectNotInRelativeImports(tree) 9 | 10 | /** Checks if for given `Select` potentially done from `TypeTree` exists import (represented by `Import`) 11 | * in this `Select` scope or its parent.
12 | */ 13 | class IsSelectNotInRelativeImports(wholeTree: Tree) { 14 | private def collectPotentialOwners(of: Select): List[Symbol] = { 15 | var owners = List.empty[Symbol] 16 | def isSelectEmbracedByTree(tree: Tree): Boolean = 17 | tree.pos.isRange && tree.pos.start <= of.pos.start && of.pos.end <= tree.pos.end 18 | val collectPotentialOwners = new Traverser { 19 | var owns = List.empty[Symbol] 20 | override def traverse(t: Tree) = { 21 | owns = currentOwner :: owns 22 | t match { 23 | case potential if isSelectEmbracedByTree(potential) => 24 | owners = owns.distinct 25 | super.traverse(t) 26 | case t => 27 | super.traverse(t) 28 | owns = owns.tail 29 | } 30 | } 31 | } 32 | collectPotentialOwners.traverse(wholeTree) 33 | owners 34 | } 35 | 36 | /** Returns `true` if an import has been found for tested `Select` and `false` otherwise. 37 | * 38 | * Note: the examples below assume that `b` in `val baz: b` produces `TypeTree` which is 39 | * converted to `Select`. So examples are just a visualization of potential use case. 40 | * 41 | * Examples: 42 | * {{{ 43 | * trait A { 44 | * import a.b 45 | * def foo = { 46 | * val baz: b = ??? 47 | * } 48 | * } 49 | * }}} 50 | * returns `true` 51 | * {{{ 52 | * def foo = { 53 | * import a.b 54 | * val baz: b = ??? 55 | * } 56 | * }}} 57 | * returns `true` 58 | * but 59 | * {{{ 60 | * def foo = { 61 | * val baz: b = ??? 62 | * import a.b 63 | * } 64 | * }}} 65 | * returns `false` 66 | * For more see tests suites. 
67 | */ 68 | def apply(tested: Select): Boolean = { 69 | val doesNameFitInTested = compareNameWith(tested) _ 70 | val nonPackageOwners = collectPotentialOwners(tested).filterNot { _.hasPackageFlag } 71 | def isValidPosition(t: Import): Boolean = t.pos.isRange && t.pos.start < tested.pos.start 72 | val isImportForTested = new Traverser { 73 | var found = false 74 | override def traverse(t: Tree) = t match { 75 | case imp: Import if isValidPosition(imp) && doesNameFitInTested(imp) && nonPackageOwners.contains(currentOwner) => 76 | found = true 77 | case t => super.traverse(t) 78 | } 79 | } 80 | isImportForTested.traverse(wholeTree) 81 | !isImportForTested.found 82 | } 83 | 84 | private def compareNameWith(tested: Select)(that: Import): Boolean = { 85 | import cuDependenciesInstance.additionalTreeMethodsForPositions 86 | def mkName(t: Tree) = if (t.symbol != null && t.symbol != NoSymbol) t.symbol.fullNameString else t.nameString 87 | val Select(testedQual, testedName) = tested 88 | val testedQName = List(mkName(testedQual), testedName).mkString(".") 89 | val Import(thatQual, thatSels) = that 90 | val impNames = thatSels.flatMap { sel => 91 | if (sel.name == nme.WILDCARD) List(mkName(thatQual)) 92 | else Set(sel.name, sel.rename).map { name => List(mkName(thatQual), name).mkString(".") }.toList 93 | } 94 | impNames.exists { testedQName.startsWith } 95 | } 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/analysis/NameValidation.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package analysis 7 | 8 | import tools.nsc.ast.parser.Scanners 9 | import tools.nsc.ast.parser.Tokens 10 | import scala.util.control.NonFatal 11 | import scala.reflect.internal.util.BatchSourceFile 12 | 13 | /** 14 | * NameValidation offers several methods to validate 15 | * new 
names; depending on the context they are used. 16 | */ 17 | trait NameValidation { 18 | 19 | self: Indexes with common.Selections with common.CompilerAccess => 20 | 21 | import global._ 22 | 23 | /** 24 | * Returns true if this name is a valid identifier, 25 | * as accepted by the Scala compiler. 26 | */ 27 | def isValidIdentifier(name: String): Boolean = { 28 | 29 | val scanner = new { val global = self.global } with Scanners { 30 | val cu = new global.CompilationUnit(new BatchSourceFile("", name)) 31 | val scanner = new UnitScanner(cu) 32 | }.scanner 33 | 34 | try { 35 | scanner.init() 36 | val firstTokenIsIdentifier = Tokens.isIdentifier(scanner.token) 37 | 38 | scanner.nextToken() 39 | val secondTokenIsEOF = scanner.token == Tokens.EOF 40 | 41 | firstTokenIsIdentifier && secondTokenIsEOF 42 | } catch { 43 | case NonFatal(_) => false 44 | } 45 | } 46 | 47 | /** 48 | * Returns all symbols that might collide with the new name 49 | * at the given symbol's location. 50 | * 51 | * For example, if the symbol is a method, it is checked if 52 | * there already exists a method with this name in the full 53 | * class hierarchy of that method's class. 54 | * 55 | * The implemented checks are only an approximation and not 56 | * necessarily correct. 
57 | */ 58 | def doesNameCollide(name: String, s: Symbol): List[Symbol] = { 59 | 60 | def isNameAlreadyUsedInLocalScope: List[Symbol] = { 61 | (index declaration s.owner map TreeSelection).toList flatMap { 62 | _.selectedSymbols.filter(_.nameString == name) 63 | } 64 | } 65 | 66 | def isNameAlreadyUsedInClassHierarchy = { 67 | index completeClassHierarchy s.owner flatMap (_.tpe.members) filter (_.nameString == name) 68 | } 69 | 70 | def isNameAlreadyUsedInPackageHierarchy = { 71 | index completePackageHierarchy s.owner flatMap (_.tpe.members) filter (_.nameString == name) 72 | } 73 | 74 | val owner = s.owner 75 | 76 | if(s.isPrivate || s.isLocal) { 77 | isNameAlreadyUsedInLocalScope.distinct 78 | } else if(owner.isClass && !(owner.isModuleClass || owner.isClass && nme.isLocalName(owner.name))) { 79 | isNameAlreadyUsedInClassHierarchy.distinct 80 | } else { 81 | isNameAlreadyUsedInPackageHierarchy.distinct 82 | } 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/analysis/TreeAnalysis.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package analysis 7 | 8 | import common.Selections 9 | 10 | /** 11 | * Provides some simple methods to analyze the program's 12 | * data flow, as used by Extract Method to find in and out 13 | * parameters. 14 | */ 15 | trait TreeAnalysis { 16 | 17 | this: Selections with Indexes with common.CompilerAccess => 18 | 19 | /** 20 | * From the selection and in the scope of the currentOwner, returns 21 | * a list of all symbols that are owned by currentOwner and used inside 22 | * but declared outside the selection. 
23 | */ 24 | @deprecated("use selection.inbondLocalDeps instead", "0.6") 25 | def inboundLocalDependencies(selection: Selection, currentOwner: global.Symbol): List[global.Symbol] = { 26 | 27 | val allLocalSymbols = selection.selectedSymbols filter { 28 | _.ownerChain.contains(currentOwner) 29 | } 30 | 31 | allLocalSymbols.filterNot { 32 | index.declaration(_).map(selection.contains) getOrElse true 33 | }.filter(t => t.pos.isOpaqueRange).sortBy(_.pos.start).distinct 34 | } 35 | 36 | /** 37 | * From the selection and in the scope of the currentOwner, returns 38 | * a list of all symbols that are defined inside the selection and 39 | * used outside of it. 40 | */ 41 | @deprecated("use selection.outboundLocalDeps instead", "0.6") 42 | def outboundLocalDependencies(selection: Selection): List[global.Symbol] = { 43 | 44 | val declarationsInTheSelection = selection.selectedSymbols filter (s => index.declaration(s).map(selection.contains) getOrElse false) 45 | 46 | val occurencesOfSelectedDeclarations = declarationsInTheSelection flatMap (index.occurences) 47 | 48 | occurencesOfSelectedDeclarations.filterNot(selection.contains).map(_.symbol).distinct 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/common/Change.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package common 7 | 8 | import scala.tools.nsc.io.AbstractFile 9 | import scala.reflect.internal.util.SourceFile 10 | 11 | /** 12 | * The common interface for all changes. 13 | */ 14 | sealed trait Change 15 | 16 | case class TextChange(sourceFile: SourceFile, from: Int, to: Int, text: String) extends Change { 17 | 18 | def file = sourceFile.file 19 | 20 | /** 21 | * Instead of a change to an existing file, return a change that creates a new file 22 | * with the change applied to the original file. 
23 | * 24 | * @param fullNewName The fully qualified package name of the target. 25 | */ 26 | def toNewFile(fullNewName: String) = { 27 | val src = Change.applyChanges(List(this), new String(sourceFile.content)) 28 | NewFileChange(fullNewName, src) 29 | } 30 | } 31 | 32 | /** 33 | * The changes creates a new source file, indicated by the `fullName` parameter. It is of 34 | * the form "some.package.FileName". 35 | */ 36 | case class NewFileChange(fullName: String, text: String) extends Change 37 | 38 | case class MoveToDirChange(sourceFile: AbstractFile, to: String) extends Change 39 | 40 | case class RenameSourceFileChange(sourceFile: AbstractFile, to: String) extends Change 41 | 42 | object Change { 43 | /** 44 | * Applies the list of changes to the source string. NewFileChanges are ignored. 45 | * Primarily used for testing / debugging. 46 | */ 47 | def applyChanges(ch: List[Change], source: String): String = { 48 | val changes = ch collect { 49 | case tc: TextChange => tc 50 | } 51 | 52 | val sortedChanges = changes.sortBy { descendingTo } 53 | 54 | /* Test if there are any overlapping text edits. This is 55 | not necessarily an error, but Eclipse doesn't allow 56 | overlapping text edits, and this helps us catch them 57 | in our own tests. 
*/ 58 | sortedChanges.sliding(2).toList foreach { 59 | case List(TextChange(_, from, _, _), TextChange(_, _, to, _)) => 60 | assert(from >= to) 61 | case _ => () 62 | } 63 | 64 | (source /: sortedChanges) { (src, change) => 65 | src.take(change.from) + change.text + src.drop(change.to) 66 | } 67 | } 68 | 69 | private def descendingTo(change: TextChange) = -change.to 70 | 71 | case class AcceptReject(accepted: List[Change], rejected: List[Change]) 72 | 73 | def discardOverlappingChanges(changes: List[Change]): AcceptReject = { 74 | val applicableChanges = changes.collect { 75 | case tc: TextChange => tc 76 | }.sortBy { descendingTo } 77 | applicableChanges.foldLeft(AcceptReject(Nil, Nil)) { (acc, ch) => acc.accepted match { 78 | case Nil => acc.copy(accepted = ch :: acc.accepted) 79 | case (h: TextChange) :: _ if h.from >= ch.to => acc.copy(accepted = ch :: acc.accepted) 80 | case _ => acc.copy(rejected = ch :: acc.rejected) 81 | } } 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/common/CompilerAccess.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package common 7 | 8 | import tools.nsc.io.AbstractFile 9 | 10 | trait CompilerAccess { 11 | 12 | val global: tools.nsc.Global 13 | 14 | def compilationUnitOfFile(f: AbstractFile): Option[global.CompilationUnit] 15 | } 16 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/common/CompilerApiExtensions.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.common 2 | 3 | /* 4 | * FIXME: This class duplicates functionality from org.scalaide.core.compiler.CompilerApiExtensions. 
5 | */ 6 | trait CompilerApiExtensions { 7 | this: CompilerAccess => 8 | import global._ 9 | 10 | /** Locate the smallest tree that encloses position. 11 | * 12 | * @param tree The tree in which to search `pos` 13 | * @param pos The position to look for 14 | * @param p An additional condition to be satisfied by the resulting tree 15 | * @return The innermost enclosing tree for which p is true, or `EmptyTree` 16 | * if the position could not be found. 17 | */ 18 | def locateIn(tree: Tree, pos: Position, p: Tree => Boolean = t => true): Tree = 19 | new FilteringLocator(pos, p) locateIn tree 20 | 21 | def enclosingPackage(tree: Tree, pos: Position): Tree = { 22 | locateIn(tree, pos, _.isInstanceOf[PackageDef]) 23 | } 24 | 25 | private class FilteringLocator(pos: Position, p: Tree => Boolean) extends Locator(pos) { 26 | override def isEligible(t: Tree) = super.isEligible(t) && p(t) 27 | } 28 | 29 | /* 30 | * For Scala-2.10 (see scala.reflect.internal.Positions.Locator in Scala-2.11). 31 | */ 32 | private class Locator(pos: Position) extends Traverser { 33 | var last: Tree = _ 34 | def locateIn(root: Tree): Tree = { 35 | this.last = EmptyTree 36 | traverse(root) 37 | this.last 38 | } 39 | protected def isEligible(t: Tree) = !t.pos.isTransparent 40 | override def traverse(t: Tree): Unit = { 41 | t match { 42 | case tt : TypeTree if tt.original != null && (tt.pos includes tt.original.pos) => 43 | traverse(tt.original) 44 | case _ => 45 | if (t.pos includes pos) { 46 | if (isEligible(t)) last = t 47 | super.traverse(t) 48 | } else t match { 49 | case mdef: MemberDef => 50 | traverseTrees(mdef.mods.annotations) 51 | case _ => 52 | } 53 | } 54 | } 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/common/InteractiveScalaCompiler.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package 
scala.tools.refactoring 6 | package common 7 | 8 | import tools.nsc.io.AbstractFile 9 | import scala.reflect.internal.util.SourceFile 10 | 11 | /** 12 | * Many parts of the library can work with the non-interactive global, 13 | * but some -- most notably the refactoring implementations -- need an 14 | * interactive compiler, which is expressed by this trait. 15 | */ 16 | trait InteractiveScalaCompiler extends CompilerAccess { 17 | 18 | val global: tools.nsc.interactive.Global 19 | 20 | def compilationUnitOfFile(f: AbstractFile) = global.unitOfFile.get(f) 21 | 22 | /** 23 | * Returns a fully loaded and typed Tree instance for the given SourceFile. 24 | */ 25 | def askLoadedAndTypedTreeForFile(file: SourceFile): Either[global.Tree, Throwable] = { 26 | val r = new global.Response[global.Tree] 27 | global.askLoadedTyped(file, r) 28 | r.get 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/common/Occurrences.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.common 2 | 3 | import scala.tools.refactoring.analysis.Indexes 4 | import scala.reflect.internal.util.RangePosition 5 | import scala.reflect.internal.util.OffsetPosition 6 | 7 | /** 8 | * Provides functionalities to get positions of term names. This includes the term name 9 | * defintion and all its uses. 
10 | */ 11 | trait Occurrences extends Selections with CompilerAccess with Indexes { 12 | import global._ 13 | 14 | type Occurrence = (Int, Int) 15 | 16 | private def termNameDefinition(root: Tree, name: String) = { 17 | root.collect { 18 | case t: DefTree if t.name.decode == name => 19 | t 20 | }.headOption 21 | } 22 | 23 | private def defToOccurrence(t: DefTree) = t.namePosition() match { 24 | case p: RangePosition => 25 | (p.start, p.end - p.start) 26 | case p: OffsetPosition => 27 | (p.point, t.name.decode.length) 28 | } 29 | 30 | private def refToOccurrence(t: RefTree) = t.pos match { 31 | case p: RangePosition => 32 | (p.start, p.end - p.start) 33 | case p: OffsetPosition => 34 | (p.point, t.symbol.name.decode.length) 35 | } 36 | 37 | /** 38 | * Returns all uses of the term name introduced by the DefTree t. 39 | */ 40 | def allOccurrences(t: DefTree): List[Occurrence] = { 41 | val refOccurrences = index.references(t.symbol).collect { 42 | case ref: RefTree => refToOccurrence(ref) 43 | } 44 | defToOccurrence(t) :: refOccurrences 45 | } 46 | 47 | /** 48 | * Searches for a definition of `name` in `root` and returns is's position 49 | * and all positions of references to the definition. 50 | * Returns an empty list if the definition of `name` is not in `selection`. 51 | */ 52 | def termNameOccurrences(root: Tree, name: String): List[Occurrence] = { 53 | termNameDefinition(root, name) match { 54 | case Some(t) => allOccurrences(t) 55 | case None => Nil 56 | } 57 | } 58 | 59 | /** 60 | * Searches for a method definition of `defName` in `root` and returns for each 61 | * method parameter a list with the position of the parameter definition and 62 | * all occurrences. 63 | * Returns an empty list if `defName` is not found, defines something that is not 64 | * a method or if the method has no parameters. 
65 | */ 66 | def defDefParameterOccurrences(root: Tree, defName: String): List[List[Occurrence]] = { 67 | termNameDefinition(root, defName) match { 68 | case Some(DefDef(_, _, _, params, _, _)) => 69 | params.flatten.map { p => allOccurrences(p) } 70 | case _ => Nil 71 | } 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/common/PositionDebugging.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.common 2 | 3 | import scala.reflect.api.Position 4 | import scala.reflect.internal.util.NoPosition 5 | import scala.reflect.internal.util.NoSourceFile 6 | import scala.tools.refactoring.getSimpleClassName 7 | 8 | /** 9 | * Some utilities for debugging purposes. 10 | */ 11 | object PositionDebugging { 12 | def format(pos: Position): String = { 13 | formatInternal(pos, false) 14 | } 15 | 16 | def formatCompact(pos: Position): String = { 17 | formatInternal(pos, true) 18 | } 19 | 20 | def format(start: Int, end: Int, source: Array[Char]): String = { 21 | def slice(start: Int, end: Int): String = { 22 | source.view(start, end).mkString("").replace("\r\n", "\\r\\n").replace("\n", "\\n") 23 | } 24 | 25 | val ctxChars = 10 26 | val l = slice(start - ctxChars, start) 27 | val m = slice(start, end) 28 | val r = slice(end, end + ctxChars) 29 | s"$l«$m»$r".trim 30 | } 31 | 32 | private def formatInternal(pos: Position, compact: Boolean): String = { 33 | if (pos != NoPosition && pos.source != NoSourceFile) { 34 | val posType = getSimpleClassName(pos) 35 | 36 | val (start, point, end) = { 37 | if (!pos.isRange) (pos.point, pos.point, pos.point) 38 | else (pos.start, pos.point, pos.end) 39 | } 40 | 41 | val markerString = { 42 | if (start == end) s"($start)" 43 | else s"($start, $point, $end)" 44 | } 45 | 46 | val relevantSource = { 47 | if (compact) "" 48 | else "[" + format(start, end, pos.source.content) + "]" 49 | } 50 | 
51 | s"$posType$markerString$relevantSource" 52 | } else { 53 | "UndefinedPosition" 54 | } 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/common/TracingHelpers.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.common 2 | 3 | object TracingHelpers { 4 | def compactify(text: String): String = { 5 | val lines = text.lines.toArray 6 | val firstLine = lines.headOption.getOrElse("") 7 | val snipAfter = 50 8 | 9 | val (compactedFirstLine, dotsAdded) = { 10 | if (firstLine.size <= snipAfter) { 11 | (firstLine, false) 12 | } else { 13 | (firstLine.substring(0, snipAfter) + "...", true) 14 | } 15 | } 16 | 17 | if (lines.size <= 1) { 18 | compactedFirstLine 19 | } else { 20 | val compactedFirstLineWithDots = { 21 | if (dotsAdded) compactedFirstLine 22 | else compactedFirstLine + "..." 23 | } 24 | 25 | val moreLines = lines.size - 1 26 | compactedFirstLineWithDots + s"($moreLines more lines ommitted)" 27 | } 28 | } 29 | 30 | def toCompactString(any: Any): String = compactify("" + any) 31 | } -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/common/TreeExtractors.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package common 7 | 8 | import PartialFunction.cond 9 | 10 | trait TreeExtractors { 11 | 12 | this: common.CompilerAccess => 13 | 14 | import global._ 15 | 16 | object Names { 17 | lazy val scala = newTermName("scala") 18 | lazy val pkg = newTermName("package") 19 | lazy val None = newTermName("None") 20 | lazy val Some = newTermName("Some") 21 | lazy val Predef = newTermName("Predef") 22 | lazy val reflect = newTermName("reflect") 23 | lazy val apply = newTermName("apply") 24 | lazy val Nil = 
newTermName("Nil") 25 | lazy val immutable = newTypeName("immutable") 26 | lazy val :: = newTermName("$colon$colon") 27 | lazy val List = newTermName("List") 28 | lazy val Seq = newTermName("Seq") 29 | lazy val collection = newTermName("collection") 30 | lazy val immutableTerm = newTermName("immutable") 31 | lazy val scalaType = newTypeName("scala") 32 | } 33 | 34 | /** 35 | * An extractor for the Some constructor. 36 | */ 37 | object SomeExpr { 38 | def unapply(t: Tree): Option[Tree] = t match { 39 | case Apply(TypeApply(Select(Select(Ident(Names.scala), Names.Some), _), (_: TypeTree) :: Nil), argToSome :: Nil) => Some(argToSome) 40 | case _ => None 41 | } 42 | } 43 | 44 | /** 45 | * A boolean extractor for the None constructor. 46 | */ 47 | object NoneExpr { 48 | def unapply(t: Tree) = cond(t) { 49 | case Select(Ident(Names.scala), Names.None) => true 50 | } 51 | } 52 | 53 | /** 54 | * An extractor for the List constructor `List` or :: 55 | */ 56 | object ListExpr { 57 | def unapply(t: Tree): Option[Tree] = t match { 58 | case Block( 59 | (ValDef(_, v1, _, arg)) :: Nil, Apply(TypeApply(Select(Select(This(Names.immutable), Names.Nil), Names.::), (_: TypeTree) :: Nil), 60 | Ident(v2) :: Nil)) if v1 == v2 => 61 | Some(arg) 62 | case Apply(TypeApply(Select(Select(This(Names.immutable), Names.List), Names.apply), (_: TypeTree) :: Nil), arg :: Nil) => 63 | Some(arg) 64 | case _ => 65 | None 66 | } 67 | } 68 | 69 | /** 70 | * A boolean extractor for the Nil object. 71 | */ 72 | object NilExpr { 73 | def unapply(t: Tree) = cond(t) { 74 | case Select(This(Names.immutable), Names.Nil) => true 75 | } 76 | } 77 | 78 | /** 79 | * An extractor for the () literal tree. 80 | */ 81 | object UnitLit { 82 | def unapply(t: Tree) = cond(t) { 83 | case Literal(c) => c.tag == UnitTag 84 | } 85 | } 86 | 87 | /** 88 | * An extractor that returns the name of a tree's 89 | * type as a String. 
90 | */ 91 | object HasType { 92 | 93 | def getTypeName(t: Type): Option[String] = t match { 94 | case TypeRef(_, sym, _) => 95 | Some(sym.nameString) 96 | case ConstantType(value) => 97 | getTypeName(value.tpe) 98 | case _ => None 99 | } 100 | 101 | def unapply(t: Tree) = { 102 | getTypeName(t.tpe) 103 | } 104 | } 105 | 106 | /** 107 | * True if the tree's type is Unit 108 | */ 109 | def hasUnitType(t: Tree) = t.tpe match { 110 | case TypeRef(_, sym, _) => sym == definitions.UnitClass 111 | case _ => false 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/common/exceptions.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package common 7 | 8 | class TreeNotFound(file: String) extends Exception("Tree not found for file "+ file +".") 9 | 10 | class RefactoringError(cause: String) extends Exception(cause) -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/common/package.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | 3 | package object common { 4 | /** 5 | * The selected tracing implementation. 6 | * 7 | * Use [[SilentTracing]] for production; consider [[DebugTracing]] for debugging. 
8 | */ 9 | type TracingImpl = SilentTracing 10 | } 11 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/common/tracing.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package common 7 | 8 | import java.io.File 9 | import java.io.FileOutputStream 10 | import java.io.PrintStream 11 | 12 | trait Tracing { 13 | protected abstract class TraceAndReturn[T] { 14 | def \\(trace: T => Unit): T 15 | } 16 | 17 | protected implicit def wrapInTraceAndReturn[T](t: T): TraceAndReturn[T] 18 | 19 | def context[T](name: String)(body: => T): T 20 | 21 | def trace(msg: => String, arg1: => Any, args: Any*): Unit 22 | 23 | def trace(msg: => String): Unit 24 | } 25 | 26 | object DebugTracing { 27 | private val debugStream = { 28 | Option(System.getProperty("scala.refactoring.traceFile")).flatMap { fn => 29 | try { 30 | val traceFile = new File(fn) 31 | val out = new FileOutputStream(traceFile, true) 32 | Some(new PrintStream(out, true, "UTF-8")) 33 | } catch { 34 | case e: Exception => 35 | e.printStackTrace() 36 | System.err.println(s"Could not open '$fn' for writing; falling back to System.out...") 37 | None 38 | } 39 | }.getOrElse(System.out) 40 | } 41 | 42 | private def printLine(str: String) = { 43 | debugStream.println(str) 44 | } 45 | } 46 | 47 | /** 48 | * Traces to STDOUT or a custom file (via the system property `scala.refactoring.traceFile`) 49 | */ 50 | trait DebugTracing extends Tracing { 51 | import DebugTracing._ 52 | 53 | var level = 0 54 | val marker = "│" 55 | val indent = " " 56 | 57 | override def context[T](name: String)(body: => T): T = { 58 | 59 | val spacer = "─" * (indent.length - 1) 60 | 61 | printLine((indent * level) +"╰"+ spacer +"┬────────" ) 62 | level += 1 63 | trace("→ "+ name) 64 | 65 | body \\ { _ => 66 | level -= 1 67 | printLine((indent * level) + 
"╭"+ spacer +"┴────────" ) 68 | } 69 | } 70 | 71 | override def trace(msg: => String, arg1: => Any, args: Any*): Unit = { 72 | 73 | val as: Array[AnyRef] = arg1 +: args.toArray map { 74 | case s: String => "«"+ s.replaceAll("\n", "\\\\n") +"»" 75 | case a: AnyRef => a 76 | } 77 | 78 | trace(msg.format(as: _*)) 79 | } 80 | 81 | override def trace(msg: => String): Unit = { 82 | val border = (indent * level) + marker 83 | printLine(border + msg.replaceAll("\n", "\n"+ border)) 84 | } 85 | 86 | protected implicit final def wrapInTraceAndReturn[T](t: T) = new TraceAndReturn[T] { 87 | def \\(trace: T => Unit) = { 88 | trace(t) 89 | t 90 | } 91 | } 92 | } 93 | 94 | trait SilentTracing extends Tracing { 95 | @inline 96 | def trace(msg: => String, arg1: => Any, args: Any*) = () 97 | 98 | @inline 99 | def trace(msg: => String) = () 100 | 101 | @inline 102 | def context[T](name: String)(body: => T): T = body 103 | 104 | protected implicit final def wrapInTraceAndReturn[T](t: T) = new TraceAndReturn[T] { 105 | def \\(ignored: T => Unit) = t 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/implementations/AddField.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | package implementations 3 | 4 | import scala.tools.nsc.io.AbstractFile 5 | import scala.tools.refactoring.common.TextChange 6 | import scala.tools.nsc.ast.parser.Tokens 7 | 8 | abstract class AddField extends AddValOrDef { 9 | 10 | val global: tools.nsc.interactive.Global 11 | import global._ 12 | 13 | def addField(file: AbstractFile, className: String, valName: String, isVar: Boolean, returnTypeOpt: Option[String], target: AddMethodTarget): List[TextChange] = 14 | addValOrDef(file, className, target, addField(valName, isVar, returnTypeOpt, _)) 15 | 16 | private def addField(valName: String, isVar: Boolean, returnTypeOpt: Option[String], 
classOrObjectDef: Tree): List[TextChange] = { 17 | val returnStatement = Ident("???") 18 | 19 | val returnType = returnTypeOpt.map(name => TypeTree(newType(name))).getOrElse(new TypeTree) 20 | 21 | val mods = if (isVar) Modifiers(Flag.MUTABLE) withPosition (Tokens.VAR, NoPosition) else NoMods 22 | 23 | val newVal = mkValOrVarDef(mods, valName, returnStatement, returnType) 24 | 25 | val insertField = insertDef(newVal) 26 | 27 | refactor((insertField apply classOrObjectDef).toList) 28 | } 29 | } -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/implementations/AddImportStatement.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package implementations 7 | 8 | import common.InteractiveScalaCompiler 9 | import common.Change 10 | import scala.tools.nsc.io.AbstractFile 11 | import scala.tools.refactoring.common.TextChange 12 | 13 | abstract class AddImportStatement extends Refactoring with InteractiveScalaCompiler { 14 | 15 | override val global: tools.nsc.interactive.Global 16 | 17 | def addImport(file: AbstractFile, fqName: String): List[TextChange] = addImports(file, List(fqName)) 18 | 19 | def addImports(file: AbstractFile, importsToAdd: Iterable[String]): List[TextChange] = { 20 | 21 | val astRoot = abstractFileToTree(file) 22 | 23 | addImportTransformation(importsToAdd.toSeq)(astRoot).toList 24 | } 25 | 26 | @deprecated("Use addImport(file, ..) instead", "0.4.0") 27 | def addImport(selection: Selection, fullyQualifiedName: String): List[Change] = { 28 | addImport(selection.file, fullyQualifiedName) 29 | } 30 | 31 | @deprecated("Use addImport(file, ..) 
instead", "0.4.0") 32 | def addImport(selection: Selection, pkg: String, name: String): List[Change] = { 33 | addImport(selection.file, pkg +"."+ name) 34 | } 35 | 36 | @deprecated("Not needed anymore, don't override.", "0.4.0") 37 | def getContentForFile(file: AbstractFile): Array[Char] = throw new UnsupportedOperationException 38 | } 39 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/implementations/AddMethod.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | package implementations 3 | 4 | import scala.tools.nsc.io.AbstractFile 5 | import scala.tools.refactoring.common.TextChange 6 | 7 | abstract class AddMethod extends AddValOrDef { 8 | 9 | val global: tools.nsc.interactive.Global 10 | import global._ 11 | 12 | def addMethod(file: AbstractFile, className: String, methodName: String, parameters: List[List[(String, String)]], typeParameters: List[String], returnType: Option[String], target: AddMethodTarget): List[TextChange] = 13 | addValOrDef(file, className, target, addMethod(methodName, parameters, typeParameters, returnType, _)) 14 | 15 | def addMethod(file: AbstractFile, className: String, methodName: String, parameters: List[List[(String, String)]], returnType: Option[String], target: AddMethodTarget): List[TextChange] = 16 | addMethod(file, className, methodName, parameters, Nil, returnType, target) 17 | 18 | private def addMethod(methodName: String, parameters: List[List[(String, String)]], typeParameters: List[String], returnTypeOpt: Option[String], classOrObjectDef: Tree): List[TextChange] = { 19 | val nscParameters = for (paramList <- parameters) yield for ((paramName, typeName) <- paramList) yield { 20 | val paramSymbol = NoSymbol.newValue(newTermName(paramName)) 21 | paramSymbol.setInfo(newType(typeName)) 22 | paramSymbol 23 | } 24 | 25 | val typeParams = if (typeParameters.nonEmpty) { 26 | val 
typeDef = new TypeDef(NoMods, newTypeName(typeParameters.mkString(", ")), Nil, EmptyTree) 27 | List(typeDef) 28 | } else Nil 29 | 30 | val returnStatement = Ident("???") :: Nil 31 | 32 | val newDef = mkDefDef(NoMods, methodName, nscParameters, returnStatement, typeParams, returnTypeOpt = returnTypeOpt.map(name => TypeTree(newType(name)))) 33 | 34 | val insertMethod = insertDef(newDef) 35 | 36 | refactor((insertMethod apply classOrObjectDef).toList) 37 | } 38 | } -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/implementations/AddValOrDef.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | package implementations 3 | 4 | import scala.tools.nsc.io.AbstractFile 5 | import scala.tools.refactoring.Refactoring 6 | import scala.tools.refactoring.common.TextChange 7 | 8 | import common.InteractiveScalaCompiler 9 | 10 | trait AddValOrDef extends Refactoring with InteractiveScalaCompiler { 11 | 12 | val global: tools.nsc.interactive.Global 13 | import global._ 14 | 15 | def addValOrDef(file: AbstractFile, className: String, target: AddMethodTarget, changeFunc: Tree => List[TextChange]): List[TextChange] = { 16 | val astRoot = abstractFileToTree(file) 17 | 18 | //it would be nice to pass in the symbol and use that rather than compare the name, but it might not be available 19 | val classOrObjectDef = target match { 20 | case AddToClosest(offset: Int) => { 21 | case class UnknownDef(tree: Tree, offset: Int) 22 | 23 | val classAndObjectDefs = astRoot.collect { 24 | case classDef: ClassDef if classDef.name.decode == className => 25 | UnknownDef(classDef, classDef.namePosition.point) 26 | case moduleDef: ModuleDef if moduleDef.name.decode == className => 27 | UnknownDef(moduleDef, moduleDef.namePosition.point) 28 | } 29 | 30 | //the class/object definition just before the given offset 31 | 
classAndObjectDefs.sortBy(_.offset).reverse.find(_.offset < offset).map(_.tree) 32 | } 33 | case _ => { 34 | astRoot.find { 35 | case classDef: ClassDef if target == AddToClass => classDef.name.decode == className 36 | case moduleDef: ModuleDef if target == AddToObject => moduleDef.name.decode == className 37 | case _ => false 38 | } 39 | } 40 | } 41 | 42 | changeFunc(classOrObjectDef.get) 43 | } 44 | 45 | protected def insertDef(valOrDef: ValOrDefDef) = { 46 | def addMethodToTemplate(tpl: Template) = tpl copy (body = tpl.body ::: valOrDef :: Nil) replaces tpl 47 | 48 | transform { 49 | case implDef: ImplDef => addMethodToTemplate(implDef.impl) 50 | } 51 | } 52 | 53 | protected def newType(name: String) = new Type { 54 | override def safeToString: String = name 55 | } 56 | } 57 | 58 | sealed trait AddMethodTarget 59 | case object AddToClass extends AddMethodTarget 60 | case object AddToObject extends AddMethodTarget 61 | case class AddToClosest(offset: Int) extends AddMethodTarget -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/implementations/ChangeParamOrder.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | package implementations 3 | 4 | 5 | /** 6 | * Refactoring that changes the order of the parameters of a method. 7 | */ 8 | abstract class ChangeParamOrder extends MethodSignatureRefactoring { 9 | 10 | import global._ 11 | 12 | type Permutation = List[Int] 13 | /** 14 | * There has to be a permutation for each parameter list of the selected method. 
15 | */ 16 | type RefactoringParameters = List[Permutation] 17 | 18 | override def checkRefactoringParams(prep: PreparationResult, affectedDefs: AffectedDefs, params: RefactoringParameters) = 19 | (prep.defdef.vparamss corresponds params) ((vp, p) => (0 until vp.length).toList == (p.sortWith(_ < _))) 20 | 21 | def reorder[T](origVparamss: List[List[T]], permutations: List[Permutation]): List[List[T]] = 22 | (origVparamss zip permutations) map { 23 | case (params, perm) => reorderSingleParamList(params, perm) 24 | } 25 | 26 | def reorderSingleParamList[T](origVparams: List[T], permutation: Permutation) = 27 | permutation map origVparams 28 | 29 | override def defdefRefactoring(parameters: RefactoringParameters) = transform { 30 | case orig @ DefDef(mods, name, tparams, vparams, tpt, rhs) => { 31 | val reorderedVparams = reorder(vparams, parameters) 32 | DefDef(mods, name, tparams, reorderedVparams, tpt, rhs) replaces orig 33 | } 34 | } 35 | 36 | override def applyRefactoring(params: RefactoringParameters) = transform { 37 | case orig @ Apply(fun, args) => { 38 | val pos = paramListPos(findOriginalTree(orig)) - 1 39 | val reorderedArgs = reorderSingleParamList(args, params(pos)) 40 | Apply(fun, reorderedArgs) replaces orig 41 | } 42 | } 43 | 44 | override def prepareParamsForSingleRefactoring(originalParams: RefactoringParameters, selectedMethod: DefDef, toRefactor: DefInfo): RefactoringParameters = { 45 | val toDrop = originalParams.size - toRefactor.nrParamLists 46 | val touchablesPrepared = originalParams.drop(toDrop) 47 | val nrUntouchables = toRefactor.nrUntouchableParamLists 48 | val ids = originalParams.drop(toDrop - nrUntouchables).take(nrUntouchables) map { perm => 49 | (0 until perm.size).toList 50 | } 51 | ids match { 52 | case Nil => touchablesPrepared 53 | case _ => ids:::originalParams.drop(toDrop) 54 | } 55 | } 56 | 57 | } -------------------------------------------------------------------------------- 
/src/main/scala/scala/tools/refactoring/implementations/ClassParameterDrivenSourceGeneration.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | package implementations 3 | 4 | import scala.tools.refactoring.MultiStageRefactoring 5 | 6 | import common.Change 7 | 8 | /** 9 | * Base class for refactorings that generate class-level source based on 10 | * the class parameters. 11 | */ 12 | abstract class ClassParameterDrivenSourceGeneration extends MultiStageRefactoring with common.InteractiveScalaCompiler { 13 | 14 | import global._ 15 | /** Result of `prepare`: the selected class, its non-private class parameters and its existing equality methods. */ 16 | case class PreparationResult( 17 | classDef: ClassDef, 18 | classParams: List[(ValDef, Boolean)], 19 | existingEqualityMethods: List[ValOrDefDef]) 20 | 21 | /** Parameters of the source generation: `callSuper` indicates whether 22 | * calls to super should be used (defaults to true), `paramsFilter` 23 | * decides for each class parameter (given as its ValDef) whether it is 24 | * passed on to `sourceGeneration` (e.g. used in equals/hashCode computations), 25 | * and `keepExistingEqualityMethods` indicates whether existing equality 26 | * methods (equals, canEqual and hashCode) should be kept or replaced.
27 | */ 28 | case class RefactoringParameters( 29 | callSuper: Boolean = true, 30 | paramsFilter: ValDef => Boolean, 31 | keepExistingEqualityMethods: Boolean) 32 | /** Fails if no class is selected, if the selection is a trait, or if `failsBecause` returns a reason; otherwise collects the class parameters and existing equality methods. */ 33 | def prepare(s: Selection) = { 34 | val notAClass = Left(PreparationError("No class definition selected.")) 35 | s.findSelectedOfType[ClassDef] match { 36 | case None => notAClass 37 | case Some(classDef) if classDef.hasSymbol && classDef.symbol.isTrait => notAClass 38 | case Some(classDef) => { 39 | failsBecause(classDef).map(PreparationError(_)) toLeft { 40 | val classParams = classDef.impl.nonPrivateClassParameters 41 | val equalityMethods = classDef.impl.existingEqualityMethods 42 | PreparationResult(classDef, classParams, equalityMethods) 43 | } 44 | } 45 | } 46 | } 47 | /** Returns `Some(reason)` if the refactoring cannot be applied to `classDef`, `None` otherwise. */ 48 | def failsBecause(classDef: ClassDef): Option[String] 49 | /** Applies `sourceGeneration` to the template of the prepared class, passing only the class parameters accepted by `params.paramsFilter`. */ 50 | def perform(selection: Selection, prep: PreparationResult, params: RefactoringParameters): Either[RefactoringError, List[Change]] = { 51 | val selectedParams = prep.classParams.map(_._1) filter params.paramsFilter 52 | 53 | val templateFilter = filter { 54 | case prep.classDef.impl => true 55 | } 56 | 57 | val refactoring = topdown { 58 | matchingChildren { 59 | templateFilter &> sourceGeneration(selectedParams, prep, params) 60 | } 61 | } 62 | 63 | Right(transformFile(selection.file, refactoring)) 64 | } 65 | /** The transformation that generates the new source; `perform` applies it only to the template of the prepared class. */ 66 | def sourceGeneration(params: List[ValDef], preparationResult: PreparationResult, refactoringParams: RefactoringParameters): Transformation[Tree, Tree] 67 | 68 | } 69 | 
-------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/implementations/ExpandCaseClassBinding.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package implementations 7 | 8 | import common.Change 9 | import scala.tools.refactoring.analysis.GlobalIndexes 10 | /** Replaces a selected binding in a case clause, whose pattern type is a case class, with a pattern listing the case accessors. */ 11 | abstract class
ExpandCaseClassBinding extends MultiStageRefactoring with ParameterlessRefactoring with GlobalIndexes { 12 | 13 | this: common.CompilerAccess => 14 | 15 | import global._ 16 | /** The selected binding, its case class symbol and the body of the enclosing case clause. */ 17 | case class PreparationResult(bind: Bind, sym: ClassSymbol, body: Tree) 18 | /** Succeeds only if a binding inside a case clause is selected and its pattern type is a case class. */ 19 | def prepare(s: Selection): Either[PreparationError, PreparationResult] = { 20 | 21 | def failure = Left(PreparationError("No binding to expand found. Please select a binding in a case clause.")) 22 | 23 | val res = s.findSelectedOfType[CaseDef] flatMap { caseDef => 24 | s.findSelectedOfType[Bind] map { 25 | case bind @ Bind(_, body) => 26 | body.tpe match { 27 | case TypeRef(_, sym: ClassSymbol, _) if sym.isCaseClass => 28 | Right(PreparationResult(bind, sym, caseDef.body)) 29 | case _ => failure 30 | } 31 | } 32 | } 33 | 34 | res getOrElse failure 35 | } 36 | /** Replaces the binding with a pattern over the case accessors, keeping `name @` only if the bound name is referenced in the case body. */ 37 | def perform(selection: Selection, preparationResult: PreparationResult): Either[RefactoringError, List[Change]] = { 38 | 39 | val PreparationResult(bind, sym, body) = preparationResult 40 | 41 | val apply = { 42 | 43 | val argSymbols = sym.info.decls.toList filter (s => s.isCaseAccessor && s.isMethod) 44 | 45 | val argIdents = argSymbols map (s => Ident(s.name)) 46 | val isTuple = sym.fullName.matches("scala\\.Tuple\\d+") 47 | 48 | if(isTuple) { 49 | // don't insert scala.TupleX in the code; user case classes that happen to be named TupleX keep their name 50 | Apply(Ident(""), argIdents) 51 | } else { 52 | Apply(Ident(sym.name), argIdents) 53 | } 54 | } 55 | 56 | val replacement = { 57 | // we create a mini-index of the casedef-body 58 | val index = GlobalIndex(body) 59 | val nameIsNotReferenced = index.references(bind.symbol).isEmpty 60 | 61 | if(nameIsNotReferenced) { 62 | 
apply 63 | } else { 64 | // creates a `name @ ` in the code 65 | bind copy (body = apply) 66 | } 67 | } replaces bind 68 | 69 | Right(refactor(List(replacement))) 70 | } 71 | /** `perform` does not use this (it builds a local index of the case body instead); calling it throws. */ 72 | def index: IndexLookup = sys.error("ExpandCaseClassBinding does not provide a global index") 73 | } 74 | --------------------------------------------------------------------------------
/src/main/scala/scala/tools/refactoring/implementations/ExplicitGettersSetters.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package implementations 7 | 8 | import common.Change 9 | import scala.tools.nsc.ast.parser.Tokens 10 | import scala.tools.nsc.symtab.Flags 11 | 12 | abstract class ExplicitGettersSetters extends MultiStageRefactoring with ParameterlessRefactoring with common.InteractiveScalaCompiler { 13 | 14 | import global._ 15 | 16 | type PreparationResult = ValDef 17 | 18 | def prepare(s: Selection) = { 19 | s.findSelectedOfType[ValDef] match { 20 | case Some(valdef) => Right(valdef) 21 | case None => Left(new PreparationError("no valdef selected")) 22 | } 23 | } 24 | 25 | override def perform(selection: Selection, selectedValue: PreparationResult): Either[RefactoringError, List[Change]] = { 26 | 27 | val template = selection.findSelectedOfType[Template].getOrElse { 28 | return Left(RefactoringError("no template found")) 29 | } 30 | 31 | val createSetter = selectedValue.symbol.isMutable 32 | 33 | val publicName = selectedValue.name.toString.trim 34 | 35 | val privateName = "_"+ publicName 36 | 37 | val privateFieldMods = if(createSetter) 38 | Modifiers(Flags.PARAMACCESSOR). 39 | withPosition (Flags.PRIVATE, NoPosition). 
40 | withPosition (Tokens.VAR, NoPosition) 41 | else 42 | Modifiers(Flags.PARAMACCESSOR) 43 | 44 | val privateField = selectedValue copy (mods = privateFieldMods, name = newTermName(privateName)) 45 | 46 | val getter = DefDef( 47 | mods = Modifiers(Flags.METHOD) withPosition (Flags.METHOD, NoPosition), 48 | name = newTermName(publicName), 49 | tparams = Nil, 50 | vparamss = Nil, 51 | tpt = EmptyTree, 52 | rhs = Block( 53 | Ident(privateName) :: Nil, EmptyTree)) 54 | 55 | val setter = DefDef( 56 | mods = Modifiers(Flags.METHOD) withPosition (Flags.METHOD, NoPosition), 57 | name = newTermName(publicName +"_="), 58 | tparams = Nil, 59 | vparamss = List(List(ValDef(Modifiers(Flags.PARAM), newTermName(publicName), TypeTree(selectedValue.tpt.tpe), EmptyTree))), 60 | tpt = EmptyTree, 61 | rhs = Block( 62 | Assign( 63 | Ident(privateName), 64 | Ident(publicName)) :: Nil, EmptyTree)) 65 | 66 | val insertGetterSettersTransformation = transform { 67 | 68 | case tpl: Template if tpl == template => 69 | 70 | val classParameters = tpl.body.map { 71 | case t: ValDef if t == selectedValue => privateField setPos t.pos 72 | case t => t 73 | } 74 | 75 | val body = if(createSetter) 76 | getter :: setter :: classParameters 77 | else 78 | getter :: classParameters 79 | 80 | tpl.copy(body = body) setPos tpl.pos 81 | } 82 | 83 | Right(transformFile(selection.file, topdown(matchingChildren(insertGetterSettersTransformation)))) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/implementations/ExtractMethod.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package implementations 7 | 8 | import scala.tools.nsc.symtab.Flags 9 | 10 | import analysis.Indexes 11 | import analysis.TreeAnalysis 12 | import common.Change 13 | import common.InteractiveScalaCompiler 14 | import 
transformation.TreeFactory 15 | 16 | abstract class ExtractMethod extends MultiStageRefactoring with TreeAnalysis with Indexes with TreeFactory with InteractiveScalaCompiler { 17 | 18 | val global: tools.nsc.interactive.Global 19 | import global._ 20 | 21 | type PreparationResult = Tree 22 | 23 | type RefactoringParameters = String 24 | 25 | def prepare(s: Selection) = { 26 | s.findSelectedOfType[DefDef] match { 27 | case _ if s.selectedTopLevelTrees.isEmpty => 28 | Left(PreparationError("No expressions or statements selected.")) 29 | case Some(tree) => 30 | Right(tree) 31 | case None => 32 | Left(PreparationError("No enclosing method definition found: please select code that's inside a method.")) 33 | } 34 | } 35 | 36 | def perform(selection: Selection, selectedMethod: PreparationResult, methodName: RefactoringParameters): Either[RefactoringError, List[Change]] = { 37 | 38 | val (call, newDef) = { 39 | 40 | val deps = { 41 | val inboundDeps = inboundLocalDependencies(selection, selectedMethod.symbol) 42 | selection.selectedTopLevelTrees match { 43 | /* extracting the condition of a for-expression */ 44 | case List(t: Function) if t.pos.isTransparent => 45 | t.vparams.map(_.symbol) ::: inboundDeps 46 | case _ => 47 | inboundDeps 48 | } 49 | } 50 | 51 | val parameters = { 52 | if(deps.isEmpty) 53 | Nil // no argument list 54 | else 55 | deps :: Nil // single argument list with all parameters 56 | } 57 | 58 | val returns = outboundLocalDependencies(selection) 59 | 60 | val returnStatement = if(returns.isEmpty) Nil else mkReturn(returns) :: Nil 61 | 62 | val newDef = mkDefDef(NoMods withPosition (Flags.PRIVATE, NoPosition), methodName, parameters, selection.selectedTopLevelTrees ::: returnStatement) 63 | 64 | val call = mkCallDefDef(methodName, deps :: Nil, returns) 65 | 66 | (call, newDef) 67 | } 68 | 69 | val extractSingleStatement = selection.selectedTopLevelTrees.size == 1 70 | 71 | val findTemplate = filter { 72 | case Template(_, _, body) => 73 | body exists (_ 
== selectedMethod) 74 | } 75 | 76 | val findMethod = filter { 77 | case d: DefDef => d == selectedMethod 78 | } 79 | 80 | val replaceBlockOfStatements = topdown { 81 | matchingChildren { 82 | transform { 83 | case block @ BlockExtractor(stats) if stats.nonEmpty => { 84 | val newStats = stats.replaceSequence(selection.selectedTopLevelTrees, call :: Nil) 85 | mkBlock(newStats) replaces block 86 | } 87 | } 88 | } 89 | } 90 | 91 | val replaceExpression = if(extractSingleStatement) 92 | replaceTree(selection.selectedTopLevelTrees.head, call) 93 | else 94 | fail[Tree] 95 | 96 | val insertMethodCall = transform { 97 | case tpl @ Template(_, _, body) => 98 | val p = selectedMethod.pos.point 99 | val (before, after) = body.span { t => 100 | !t.pos.isRange /* to skip synthetic methods*/ || t.pos.point <= p 101 | } 102 | tpl copy(body = before ::: newDef :: after) replaces tpl 103 | } 104 | 105 | val extractMethod = topdown { 106 | matchingChildren { 107 | findTemplate &> 108 | topdown { 109 | matchingChildren { 110 | findMethod &> replaceBlockOfStatements |> replaceExpression 111 | } 112 | } &> 113 | insertMethodCall 114 | } 115 | } 116 | 117 | Right(transformFile(selection.file, extractMethod)) 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/implementations/GenerateHashcodeAndEquals.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | package implementations 3 | 4 | /** 5 | * Refactoring that generates hashCode and equals implementations 6 | * following the recommendations given in chapter 28 of 7 | * Programming in Scala. 
8 | */ 9 | abstract class GenerateHashcodeAndEquals extends ClassParameterDrivenSourceGeneration { 10 | 11 | import global._ 12 | 13 | override def failsBecause(classDef: ClassDef) = { 14 | None 15 | } 16 | 17 | override def sourceGeneration(selectedParams: List[ValDef], preparationResult: PreparationResult, refactoringParams: RefactoringParameters) = { 18 | 19 | val equalityMethods = mkEqualityMethods(preparationResult.classDef.symbol, selectedParams, refactoringParams.callSuper) 20 | val newParents = newParentNames(preparationResult.classDef, selectedParams).map(name => Ident(newTermName(name))) 21 | 22 | val equalityMethodNames = List(nme.equals_, nme.hashCode_, nme.canEqual_).map(_.toString) 23 | def isEqualityMethod(t: Tree) = t match { 24 | case d: ValOrDefDef => equalityMethodNames contains d.nameString 25 | case _ => false 26 | } 27 | 28 | def addEqualityMethods = transform { 29 | case t @ Template(parents, self, body) => { 30 | val bodyFilter: Tree => Boolean = if (refactoringParams.keepExistingEqualityMethods) 31 | (t: Tree) => true else (t: Tree) => !isEqualityMethod(t) 32 | val filteredBody = body.filter(bodyFilter) 33 | val equalityMethodsInBody = filteredBody collect {case d: ValOrDefDef if equalityMethodNames contains d.nameString => d.name } 34 | val filteredEqualityMethods = equalityMethods.filter(e => !(equalityMethodsInBody contains e.name)) 35 | Template(parents:::newParents, self, filteredBody:::filteredEqualityMethods) replaces t 36 | } 37 | } 38 | 39 | addEqualityMethods 40 | } 41 | 42 | private def mkEqualityMethods(classSymbol: Symbol, params: List[ValDef], callSuper: Boolean) = { 43 | val canEqual = mkCanEqual(classSymbol) 44 | val hashcode = mkHashcode(classSymbol, params, callSuper) 45 | val equals = mkEquals(classSymbol, params, callSuper) 46 | 47 | canEqual::equals::hashcode::Nil 48 | } 49 | 50 | protected def newParentNames(classDef: ClassDef, selectedParams: List[ValDef]): List[String] = { 51 | val existingParents = 
classDef.impl.parents.map(_.nameString) 52 | if(existingParents contains "Equals") 53 | Nil 54 | else 55 | "Equals"::Nil 56 | } 57 | 58 | } 59 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/implementations/ImportsHelper.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | package implementations 3 | 4 | import scala.language.reflectiveCalls 5 | import scala.tools.refactoring.common.TracingImpl 6 | import scala.tools.refactoring.implementations.oimports.OrganizeImportsWorker 7 | 8 | trait ImportsHelper extends TracingImpl { 9 | 10 | self: common.InteractiveScalaCompiler with analysis.Indexes with transformation.Transformations with transformation.TreeTransformations with common.EnrichedTrees => 11 | 12 | import global._ 13 | 14 | def addRequiredImports(importsUser: Option[Tree], targetPackageName: Option[String]) = traverseAndTransformAll { 15 | findBestPackageForImports &> transformation[(PackageDef, List[Import], List[Tree]), Tree] { 16 | case (pkg, existingImports, rest) => { 17 | val user = importsUser getOrElse pkg 18 | val targetPkgName = targetPackageName getOrElse pkg.nameString 19 | val oi = new OrganizeImportsWorker[global.type](global) { 20 | import participants._ 21 | import regionContext._ 22 | import transformations._ 23 | import treeToolbox._ 24 | object NeededImports extends Participant { 25 | def doApply(trees: List[Import]) = { 26 | val externalDependencies = neededImports(user) filterNot { imp => 27 | // We don't want to add imports for types that are 28 | // children of `importsUser`. 
29 | index.declaration(imp.symbol).exists { declaration => 30 | val sameFile = declaration.pos.source.file.canonicalPath == user.pos.source.file.canonicalPath 31 | def userPosIncludesDeclPos = user.pos.includes(declaration.pos) 32 | sameFile && userPosIncludesDeclPos 33 | } 34 | } 35 | 36 | val newImportsToAdd = externalDependencies filterNot { 37 | case Select(qualifier, name) => 38 | val depPkgStr = importAsString(qualifier) 39 | val depNameStr = "" + name 40 | 41 | trees exists { 42 | case Import(expr, selectors) => 43 | val impPkgStr = importAsString(expr) 44 | 45 | selectors exists { selector => 46 | val selNameStr = "" + selector.name 47 | val selRenameStr = "" + selector.rename 48 | 49 | impPkgStr == depPkgStr && { 50 | selector.name == nme.WILDCARD || { 51 | selNameStr == depNameStr || selRenameStr == depNameStr 52 | } 53 | } 54 | } 55 | } 56 | } 57 | val existingStillNeededImports = new recomputeAndModifyUnused(user)( 58 | List(Region(trees.map { imp => new RegionImport(proto = imp)() }))) 59 | 60 | existingStillNeededImports.flatMap { _.imports } ::: SortImports(mkImportTrees(newImportsToAdd, targetPkgName)) 61 | } 62 | } 63 | } 64 | 65 | val imports = oi.NeededImports(existingImports).filterNot { imp => 66 | val noLongerNeeded = targetPkgName == imp.expr.toString && imp.selectors.size == 1 && { 67 | // Note that we don't touch imports with multiple selectors here. This limitation, 68 | // that should not result in any regressions, might be addressed in the future. 
69 | val s = imp.selectors.head 70 | s.name == s.rename 71 | } 72 | 73 | noLongerNeeded 74 | } 75 | 76 | // When we move the whole file, we only want to add imports to the originating package 77 | pkg copy (stats = imports ::: rest) replaces pkg 78 | } 79 | } 80 | } 81 | 82 | } 83 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/implementations/InlineLocal.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package implementations 7 | 8 | import common.Change 9 | import transformation.TreeFactory 10 | import analysis.Indexes 11 | 12 | abstract class InlineLocal extends MultiStageRefactoring with ParameterlessRefactoring with TreeFactory with Indexes with common.InteractiveScalaCompiler { 13 | 14 | import global._ 15 | 16 | override type PreparationResult = ValDef 17 | 18 | override def prepare(s: Selection) = { 19 | 20 | val selectedValue = s.findSelectedOfType[RefTree] match { 21 | case Some(t) => 22 | index.declaration(t.symbol) match { 23 | case Some(v: ValDef) => Some(v) 24 | case _ => None 25 | } 26 | case None => s.findSelectedOfType[ValDef] 27 | } 28 | 29 | def isInliningAllowed(sym: Symbol) = 30 | (sym.isPrivate || sym.isLocal) && !sym.isMutable && !sym.isValueParameter 31 | 32 | selectedValue match { 33 | case Some(t) if isInliningAllowed(t.symbol) => 34 | Right(t) 35 | case Some(t) => 36 | Left(PreparationError("The selected value cannot be inlined.")) 37 | case None => 38 | Left(PreparationError("No local value selected.")) 39 | } 40 | } 41 | 42 | override def perform(selection: Selection, selectedValue: PreparationResult): Either[RefactoringError, List[Change]] = { 43 | 44 | trace("Selected: %s", selectedValue) 45 | 46 | val removeSelectedValue = { 47 | 48 | def replaceSelectedValue(ts: List[Tree]) = { 49 | ts replaceSequence (List(selectedValue), Nil) 
50 | } 51 | 52 | transform { 53 | case tpl @ Template(_, _, stats) if stats contains selectedValue => 54 | tpl.copy(body = replaceSelectedValue(stats)) replaces tpl 55 | case block @ BlockExtractor(stats) if stats contains selectedValue => 56 | mkBlock(replaceSelectedValue(stats)) replaces block 57 | } 58 | } 59 | 60 | val references = index references selectedValue.symbol 61 | 62 | val replaceReferenceWithRhs = { 63 | 64 | val replacement = selectedValue.rhs match { 65 | // inlining `list.filter _` should not include the `_` 66 | case Function(vparams, Apply(fun, args)) if vparams forall (_.symbol.isSynthetic) => fun 67 | case t => t 68 | } 69 | 70 | trace("Value is referenced on lines: %s", references map (_.pos.lineContent) mkString "\n ") 71 | 72 | transform { 73 | case t if references contains t => replacement 74 | } 75 | } 76 | 77 | if(references.isEmpty) { 78 | Left(RefactoringError("No references to selected val found.")) 79 | } else { 80 | Right(transformFile(selection.file, topdown(matchingChildren(removeSelectedValue &> topdown(matchingChildren(replaceReferenceWithRhs)))))) 81 | } 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/implementations/IntroduceProductNTrait.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | package implementations 3 | 4 | 5 | /** 6 | * Refactoring that implements the ProductN trait for a class. 7 | * Given N selected class parameters this refactoring generates 8 | * the methods needed to implement the ProductN trait. This includes 9 | * implementations for hashCode and equals. 
10 | * @see GenerateHashcodeAndEquals 11 | */ 12 | abstract class IntroduceProductNTrait extends GenerateHashcodeAndEquals { 13 | 14 | import global._ 15 | 16 | override def sourceGeneration(selectedParams: List[ValDef], preparationResult: PreparationResult, refactoringParams: RefactoringParameters) = { 17 | val superGeneration = super.sourceGeneration(selectedParams, preparationResult, refactoringParams) 18 | 19 | val projections = { 20 | def makeElemProjection(elem: ValDef, pos: Int) = { 21 | val body = List(Ident(elem.name)) 22 | mkDefDef(name = "_" + pos, body = body) 23 | } 24 | 25 | selectedParams.zipWithIndex.map(t => makeElemProjection(t._1, t._2 + 1)) 26 | } 27 | 28 | def addProductTrait = transform ({ 29 | case t @ Template(_, _, body) => t.copy(body = projections:::body) replaces t 30 | }) 31 | 32 | superGeneration &> addProductTrait 33 | } 34 | 35 | override def newParentNames(classDef: ClassDef, selectedParams: List[ValDef]) = { 36 | val arity = selectedParams.length 37 | val paramsTypenames = selectedParams.map(v => v.tpt.nameString) 38 | val productParent = "Product" + arity + "[" + paramsTypenames.mkString(", ") + "]" 39 | productParent::Nil 40 | } 41 | 42 | } 43 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/implementations/MergeParameterLists.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | package implementations 3 | 4 | /** 5 | * Refactoring to merge parameter lists of a method. 
6 | */ 7 | abstract class MergeParameterLists extends MethodSignatureRefactoring { 8 | 9 | import global._ 10 | 11 | type MergePositions = List[Int] 12 | type RefactoringParameters = MergePositions 13 | 14 | override def checkRefactoringParams(prep: PreparationResult, affectedDefs: AffectedDefs, params: RefactoringParameters) = { 15 | val selectedDefDef = prep.defdef 16 | val allowedMergeIndexesRange = 1 until selectedDefDef.vparamss.size 17 | val isNotEmpty = (p: RefactoringParameters) => !p.isEmpty 18 | val isSorted = (p: RefactoringParameters) => (p sortWith (_ < _)) == p 19 | val uniqueIndexes = (p: RefactoringParameters) => p.distinct == p 20 | val indexesInRange = (p: RefactoringParameters) => allowedMergeIndexesRange containsSlice (p.head to p.last) 21 | val mergeable = (p: RefactoringParameters) => { 22 | val allAffectedDefs = affectedDefs.originals:::affectedDefs.partials 23 | val preparedParams = allAffectedDefs.map(prepareParamsForSingleRefactoring(params, selectedDefDef, _)) 24 | preparedParams.filter(_ contains 0).isEmpty 25 | } 26 | val allConditions = List(isNotEmpty, isSorted, uniqueIndexes, indexesInRange, mergeable) 27 | allConditions.foldLeft(true)((b, f) => b && f(params)) 28 | } 29 | 30 | override def defdefRefactoring(params: RefactoringParameters) = transform { 31 | case orig @ DefDef(mods, name, tparams, vparamss, tpt, rhs) => { 32 | val vparamssWithIndex = vparamss.zipWithIndex 33 | val mergedVparamss = vparamssWithIndex.foldLeft(Nil: List[List[ValDef]])((acc, current) => current match { 34 | case (_, index) if params contains index => (acc.head:::current._1)::acc.tail 35 | case _ => current._1::acc 36 | }).reverse 37 | DefDef(mods, name, tparams, mergedVparamss, tpt, rhs) replaces orig 38 | } 39 | } 40 | 41 | override def applyRefactoring(params: RefactoringParameters) = transform { 42 | case apply @ Apply(fun, args) => { 43 | val originalTree = findOriginalTree(apply) 44 | val pos = paramListPos(originalTree) - 1 45 | if(params contains 
pos) { 46 | fun match { 47 | case Apply(ffun, fargs) => Apply(ffun, fargs:::args) 48 | case Select(Apply(ffun, fargs), name) => Apply(ffun, fargs:::args) 49 | case _ => apply 50 | } 51 | } else { 52 | apply 53 | } 54 | } 55 | } 56 | 57 | override def traverseApply(t: ⇒ Transformation[Tree, Tree]) = bottomup(t) 58 | 59 | override def prepareParamsForSingleRefactoring(originalParams: RefactoringParameters, selectedMethod: DefDef, toRefactor: DefInfo): RefactoringParameters = { 60 | val originalNrParamLists = selectedMethod.vparamss.size 61 | val currentNrParamLists = toRefactor.nrParamLists 62 | val untouchables = toRefactor.nrUntouchableParamLists 63 | val toShift = originalNrParamLists - currentNrParamLists - untouchables 64 | originalParams.map(_ - toShift).filter(_ >= untouchables) 65 | } 66 | } -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/implementations/SplitParameterLists.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | package implementations 3 | 4 | /** 5 | * Refactoring to split parameter lists of a method. 
6 | */ 7 | abstract class SplitParameterLists extends MethodSignatureRefactoring { 8 | 9 | import global._ 10 | 11 | type SplitPositions = List[Int] 12 | /** 13 | * Split positions must be provided for every parameter list of the method (though they can be Nil) 14 | */ 15 | type RefactoringParameters = List[SplitPositions] 16 | 17 | override def checkRefactoringParams(prep: PreparationResult, affectedDefs: AffectedDefs, params: RefactoringParameters) = { 18 | def checkRefactoringParamsHelper(vparamss: List[List[ValDef]], sectionss: List[SplitPositions]): Boolean = { 19 | val sortedSections = sectionss.map(Set(_: _*).toList.sorted) 20 | if(sortedSections != sectionss || vparamss.size != sectionss.size) { 21 | false 22 | } else { 23 | val emptyRange = 1 to 0 24 | val sectionRanges = sectionss map { case Nil => emptyRange ; case s => s.head to s.last } 25 | val vparamsRanges = vparamss.map(1 until _.size) 26 | (vparamsRanges zip sectionRanges).foldLeft(true)((b, ranges) => b && (ranges._1 containsSlice ranges._2)) 27 | } 28 | } 29 | 30 | checkRefactoringParamsHelper(prep.defdef.vparamss, params) 31 | } 32 | 33 | def splitSingleParamList[T](origVparams: List[T], positions: SplitPositions): List[List[T]] = { 34 | val nrParamsPerList = (positions:::List(origVparams.length) zip 0::positions) map (t => t._1 - t._2) 35 | nrParamsPerList.foldLeft((Nil: List[List[T]] , origVparams))((acc, nrParams) => { 36 | val (currentCurriedParamList, remainingOrigParams) = acc._2 splitAt nrParams 37 | (acc._1:::List(currentCurriedParamList), remainingOrigParams) 38 | })._1 39 | } 40 | 41 | def makeSplitApply(baseFun: Tree, vparamss: List[List[Tree]]) = { 42 | val firstApply = Apply(baseFun, vparamss.headOption.getOrElse(throw new IllegalArgumentException("can't handle empty vparamss"))) 43 | vparamss.tail.foldLeft(firstApply)((fun, vparams) => Apply(fun, vparams)) 44 | } 45 | 46 | override def defdefRefactoring(params: RefactoringParameters) = transform { 47 | case orig @ DefDef(mods, 
name, tparams, vparamss, tpt, rhs) => { 48 | val split = (vparamss zip params) flatMap (l => splitSingleParamList(l._1, l._2)) 49 | DefDef(mods, name, tparams, split, tpt, rhs) replaces orig 50 | } 51 | } 52 | 53 | override def applyRefactoring(params: RefactoringParameters) = transform { 54 | case apply @ Apply(fun, args) => { 55 | val originalTree = findOriginalTree(apply) 56 | val pos = paramListPos(originalTree) - 1 57 | val splitParamLists = splitSingleParamList(apply.args, params(pos)) 58 | val splitApply = makeSplitApply(fun, splitParamLists) 59 | splitApply replaces apply 60 | } 61 | } 62 | 63 | override def traverseApply(t: ⇒ Transformation[Tree, Tree]) = bottomup(t) 64 | 65 | override def prepareParamsForSingleRefactoring(originalParams: RefactoringParameters, selectedMethod: DefDef, toRefactor: DefInfo): RefactoringParameters = { 66 | val toDrop = originalParams.size - toRefactor.nrParamLists 67 | val preparedParams = originalParams.drop(toDrop) 68 | val noSplitters = (1 to toRefactor.nrUntouchableParamLists).toList 69 | noSplitters match { 70 | case Nil => preparedParams 71 | case _ => noSplitters.map(_ => Nil):::preparedParams 72 | } 73 | 74 | 75 | } 76 | 77 | } -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/implementations/extraction/ExtractCode.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.implementations.extraction 2 | 3 | /** 4 | * General extraction refactoring that proposes different concrete extractions based on the 5 | * current selection. 6 | */ 7 | abstract class ExtractCode extends ExtractionRefactoring with AutoExtractions { 8 | val collector = AutoExtraction 9 | } 10 | 11 | trait AutoExtractions extends MethodExtractions with ValueExtractions with ExtractorExtractions with ParameterExtractions { 12 | /** 13 | * Proposes different kinds of extractions. 
14 | */ 15 | object AutoExtraction extends ExtractionCollector[Extraction] { 16 | /** 17 | * Extraction collectors used for auto extraction. 18 | */ 19 | val availableCollectors = 20 | ExtractorExtraction :: 21 | ValueExtraction :: 22 | MethodExtraction :: 23 | ParameterExtraction :: 24 | Nil 25 | 26 | /** 27 | * Searches for an extraction source that is valid for at least one 28 | * extraction collector. If an appropriate source is found it calls 29 | * `collect(s)` on every collector that accepts this source. 30 | */ 31 | override def collect(s: Selection) = { 32 | var applicableCollectors: List[ExtractionCollector[Extraction]] = Nil 33 | val sourceOpt = s.expand.expandTo { source: Selection => 34 | applicableCollectors = availableCollectors.filter(_.isValidExtractionSource(source)) 35 | !applicableCollectors.isEmpty 36 | } 37 | 38 | val extractions = applicableCollectors.flatMap { collector => 39 | collector.collect(sourceOpt.get).right.getOrElse(Nil) 40 | } 41 | 42 | if (extractions.isEmpty) 43 | Left("No applicable extraction found.") 44 | else 45 | Right(extractions.sortBy(-_.extractionTarget.enclosing.pos.startOrPoint)) 46 | } 47 | 48 | def isValidExtractionSource(s: Selection) = ??? 49 | 50 | def createExtractions(source: Selection, targets: List[ExtractionTarget], name: String) = ??? 51 | } 52 | } -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/implementations/extraction/ExtractParameter.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.implementations.extraction 2 | 3 | import scala.tools.refactoring.analysis.ImportAnalysis 4 | 5 | abstract class ExtractParameter extends ExtractionRefactoring with ParameterExtractions { 6 | val collector = ParameterExtraction 7 | } 8 | 9 | /** 10 | * Extracts an expression into a new parameter whose default value is 11 | * the extracted expression. 
12 | */ 13 | trait ParameterExtractions extends Extractions with ImportAnalysis { 14 | import global._ 15 | 16 | object ParameterExtraction extends ExtractionCollector[ParameterExtraction] { 17 | def isValidExtractionSource(s: Selection) = 18 | s.representsValue && !s.representsParameter 19 | 20 | override def createInsertionPosition(s: Selection) = 21 | atEndOfValueParameterList 22 | 23 | def createExtractions(source: Selection, targets: List[ExtractionTarget], name: String) = { 24 | val validTargets = targets.takeWhile { t => 25 | source.inboundLocalDeps.forall(t.scope.sees(_)) 26 | } 27 | 28 | validTargets.map(ParameterExtraction(source, _, name)) 29 | } 30 | } 31 | 32 | case class ParameterExtraction( 33 | extractionSource: Selection, 34 | extractionTarget: ExtractionTarget, 35 | abstractionName: String) extends Extraction { 36 | 37 | val displayName = extractionTarget.enclosing match { 38 | case t: DefDef => s"Extract Parameter to Method ${t.symbol.nameString}" 39 | } 40 | 41 | val functionOrDefDef = extractionTarget.enclosing 42 | 43 | def perform() = { 44 | val tpe = defaultVal.tpe 45 | val param = mkParam(abstractionName, tpe, defaultVal) 46 | val paramRef = Ident(newTermName(abstractionName)) 47 | 48 | extractionSource.replaceBy(paramRef) :: 49 | extractionTarget.insert(param) :: 50 | Nil 51 | } 52 | 53 | val defaultVal = { 54 | extractionSource.selectedTopLevelTrees.last 55 | } 56 | 57 | def withAbstractionName(name: String) = 58 | copy(abstractionName = name).asInstanceOf[this.type] 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/implementations/extraction/ExtractValue.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.implementations.extraction 2 | 3 | import scala.tools.refactoring.analysis.ImportAnalysis 4 | 5 | /** 6 | * Extracts one or more expressions into a new val definition. 
7 | */ 8 | abstract class ExtractValue extends ExtractionRefactoring with ValueExtractions { 9 | val collector = ValueExtraction 10 | } 11 | 12 | trait ValueExtractions extends Extractions with ImportAnalysis { 13 | import global._ 14 | 15 | object ValueExtraction extends ExtractionCollector[ValueExtraction] { 16 | def isValidExtractionSource(s: Selection) = 17 | (s.representsValue || s.representsValueDefinitions) && !s.representsParameter 18 | 19 | def createExtractions(source: Selection, targets: List[ExtractionTarget], name: String) = { 20 | val validTargets = targets.takeWhile { t => 21 | source.inboundLocalDeps.forall(t.scope.sees(_)) 22 | } 23 | 24 | validTargets.map(ValueExtraction(source, _, name)) 25 | } 26 | } 27 | 28 | case class ValueExtraction( 29 | extractionSource: Selection, 30 | extractionTarget: ExtractionTarget, 31 | abstractionName: String) extends Extraction { 32 | 33 | val displayName = extractionTarget.enclosing match { 34 | case t: Template => s"Extract Value to ${t.symbol.owner.decodedName}" 35 | case _ => "Extract Local Value" 36 | } 37 | 38 | def withAbstractionName(name: String) = 39 | copy(abstractionName = name).asInstanceOf[this.type] 40 | 41 | lazy val imports = buildImportTree(extractionSource.root) 42 | 43 | def perform() = { 44 | val outboundDeps = extractionSource.outboundLocalDeps 45 | val call = mkCallValDef(abstractionName, outboundDeps) 46 | 47 | val returnStatements = 48 | if (outboundDeps.isEmpty) Nil 49 | else mkReturn(outboundDeps) :: Nil 50 | 51 | val importStatements = extractionSource.selectedTopLevelTrees.flatMap(imports.findRequiredImports(_, extractionSource.pos, extractionTarget.pos)) 52 | 53 | val extractedStatements = extractionSource.selectedTopLevelTrees match { 54 | case (fn @ Function(vparams, Block(stmts, expr))) :: Nil => 55 | val tpe = fn.tpe 56 | val newFn = fn copy (body = Block(stmts, expr)) 57 | newFn.tpe = tpe 58 | newFn :: Nil 59 | case ts => ts 60 | } 61 | 62 | val statements = importStatements ::: 
extractedStatements ::: returnStatements 63 | 64 | val abstraction = statements match { 65 | case expr :: Nil => 66 | expr match { 67 | // Add explicit type annotations for extracted functions 68 | case fn: Function => 69 | mkValDef(abstractionName, fn, TypeTree(fn.tpe)) 70 | case t => 71 | mkValDef(abstractionName, t) 72 | } 73 | case stmts => mkValDef(abstractionName, mkBlock(stmts)) 74 | } 75 | 76 | extractionSource.replaceBy(call, preserveHierarchy = true) :: 77 | extractionTarget.insert(abstraction) :: 78 | Nil 79 | } 80 | } 81 | } -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/implementations/oimports/ImportsOrganizer.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | package implementations.oimports 3 | 4 | import scala.tools.nsc.Global 5 | 6 | import sourcegen.Formatting 7 | 8 | abstract class ImportsOrganizer[G <: Global, U <: TreeToolbox[G]](val treeToolbox: U) { 9 | import treeToolbox.global._ 10 | type T <: Tree 11 | 12 | protected def noAnyTwoImportsInSameLine(importsGroup: List[Import]): Boolean = 13 | importsGroup.size == importsGroup.map { _.pos.line }.distinct.size 14 | 15 | private def importsGroupsFromTree(trees: List[Tree]): List[List[Import]] = { 16 | val groupedImports = trees.foldLeft(List.empty[List[Import]]) { (acc, tree) => 17 | tree match { 18 | case imp: Import => 19 | val lastUpdated = acc.lastOption.map { _ :+ imp }.getOrElse(List(imp)) 20 | acc.take(acc.length - 1) :+ lastUpdated 21 | case _ => acc :+ List.empty[Import] 22 | } 23 | }.filter { _.nonEmpty } 24 | groupedImports 25 | } 26 | 27 | protected def forTreesOf(tree: Tree): List[(T, Symbol)] 28 | 29 | protected def treeChildren(parent: T): List[Tree] 30 | 31 | private def toRegions(groupedImports: List[List[Import]], importsOwner: Symbol, formatting: Formatting): List[treeToolbox.Region] = 32 | groupedImports.collect { 33 | case 
imports @ h :: _ => RegionBuilder[G, U](treeToolbox)(imports, importsOwner, formatting, "") 34 | }.flatten 35 | 36 | def transformTreeToRegions(tree: Tree, formatting: Formatting): List[treeToolbox.Region] = forTreesOf(tree).flatMap { 37 | case (extractedTree, treeOwner) => 38 | val groupedImports = importsGroupsFromTree(treeChildren(extractedTree)).filter { 39 | noAnyTwoImportsInSameLine 40 | } 41 | toRegions(groupedImports, treeOwner, formatting) 42 | } 43 | } 44 | 45 | class DefImportsOrganizer[G <: Global, U <: TreeToolbox[G]](override val treeToolbox: U) extends ImportsOrganizer[G, U](treeToolbox) { 46 | import treeToolbox.global._ 47 | type T = Block 48 | import treeToolbox.forTreesOfKind 49 | 50 | override protected def forTreesOf(tree: Tree) = forTreesOfKind[Block](tree) { treeCollector => 51 | { 52 | case b @ Block(stats, expr) if treeCollector.currentOwner.isMethod && !treeCollector.currentOwner.isLazy => 53 | treeCollector.collect(b) 54 | stats.foreach { treeCollector.traverse } 55 | treeCollector.traverse(expr) 56 | } 57 | } 58 | 59 | override protected def treeChildren(block: Block) = block.stats 60 | } 61 | 62 | class ClassDefImportsOrganizer[G <: Global, U <: TreeToolbox[G]](override val treeToolbox: U) extends ImportsOrganizer[G, U](treeToolbox) { 63 | import treeToolbox.global._ 64 | type T = Template 65 | import treeToolbox.forTreesOfKind 66 | 67 | override protected def forTreesOf(tree: Tree) = forTreesOfKind[Template](tree) { treeCollector => 68 | { 69 | case t @ Template(_, _, body) => 70 | treeCollector.collect(t) 71 | body.foreach { treeCollector.traverse } 72 | } 73 | } 74 | 75 | override protected def treeChildren(template: Template) = template.body 76 | } 77 | 78 | class PackageDefImportsOrganizer[G <: Global, U <: TreeToolbox[G]](override val treeToolbox: U) extends ImportsOrganizer[G, U](treeToolbox) { 79 | import treeToolbox.global._ 80 | type T = PackageDef 81 | import treeToolbox.forTreesOfKind 82 | 83 | override protected def 
forTreesOf(tree: Tree) = forTreesOfKind[PackageDef](tree) { treeCollector => 84 | { 85 | case p @ PackageDef(pid, stats) => 86 | treeCollector.collect(p, p.symbol.asTerm.referenced) 87 | stats.foreach { treeCollector.traverse } 88 | } 89 | } 90 | 91 | override protected def noAnyTwoImportsInSameLine(importsGroup: List[Import]): Boolean = true 92 | override protected def treeChildren(packageDef: PackageDef) = packageDef.stats 93 | } 94 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/package.scala: -------------------------------------------------------------------------------- 1 | package scala.tools 2 | 3 | import scala.tools.nsc.interactive.PresentationCompilerThread 4 | 5 | package object refactoring { 6 | 7 | /** 8 | * Asserts that the current operation is running on the thread 9 | * of the presentation compiler (PC). This is necessary because many 10 | * operations on compiler symbols can trigger further compilation, 11 | * which needs to be done on the PC thread. 12 | * 13 | * To run an operation on the PC thread, use global.ask { .. } 14 | */ 15 | def assertCurrentThreadIsPresentationCompiler(): Unit = { 16 | val msg = "operation should be running on the presentation compiler thread" 17 | assert(Thread.currentThread.isInstanceOf[PresentationCompilerThread], msg) 18 | } 19 | 20 | /** 21 | * Safe way to get a simple class name from an object. 22 | * 23 | * Using getClass.getSimpleName can sometimes lead to InternalError("Malformed class name") 24 | * being thrown, so we catch that. 
Probably related to #SI-2034 25 | */ 26 | def getSimpleClassName(o: Object): String = try { 27 | o.getClass.getSimpleName 28 | } catch { 29 | case _: InternalError | _: NoClassDefFoundError => o.getClass.getName 30 | } 31 | } -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/sourcegen/AbstractPrinter.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package sourcegen 7 | 8 | import scala.reflect.internal.util.SourceFile 9 | 10 | trait AbstractPrinter extends CommonPrintUtils { 11 | 12 | this: common.Tracing with common.EnrichedTrees with Indentations with common.CompilerAccess with Formatting => 13 | 14 | import global._ 15 | 16 | /** 17 | * PrintingContext is passed around with all the print methods and contains 18 | * the context or environment for the current printing. 19 | */ 20 | case class PrintingContext(ind: Indentation, changeSet: ChangeSet, parent: Tree, file: Option[SourceFile]) { 21 | lazy val newline: String = { 22 | if(file.exists(_.content.containsSlice("\r\n"))) 23 | "\r\n" 24 | else 25 | "\n" 26 | } 27 | } 28 | 29 | trait ChangeSet { 30 | def hasChanged(t: Tree): Boolean 31 | } 32 | 33 | def print(t: Tree, ctx: PrintingContext): Fragment 34 | 35 | } 36 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/sourcegen/CommentsUtils.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.sourcegen 2 | 3 | /** 4 | * Only here for backward compatibility - use [[SourceUtils]] instead 5 | */ 6 | object CommentsUtils extends SourceUtils 7 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/sourcegen/CommonPrintUtils.scala: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package sourcegen 7 | 8 | import scala.reflect.internal.util.BatchSourceFile 9 | 10 | trait CommonPrintUtils { 11 | 12 | this: common.CompilerAccess with AbstractPrinter => 13 | 14 | import global._ 15 | 16 | def newline(implicit ctx: PrintingContext) = Requisite.newline("", ctx.newline) 17 | 18 | def indentedNewline(implicit ctx: PrintingContext) = Requisite.newline(ctx.ind.current, ctx.newline) 19 | 20 | def newlineIndentedToChildren(implicit ctx: PrintingContext) = Requisite.newline(ctx.ind.incrementDefault.current, ctx.newline) 21 | 22 | def indentation(implicit ctx: PrintingContext) = ctx.ind.current 23 | 24 | def typeToString(tree: TypeTree, t: Type)(implicit ctx: PrintingContext): String = { 25 | t match { 26 | case tpe if tpe == EmptyTree.tpe => "" 27 | case tpe: ConstantType => 28 | tpe.typeSymbol.tpe.toString 29 | case tpe: TypeRef if tree.original != null && tpe.sym.nameString.matches("Tuple\\d+") => 30 | tpe.toString 31 | case tpe if tree.original != null && !tpe.isInstanceOf[TypeRef] => 32 | print(tree.original, ctx).asText 33 | case tpe: RefinedType => 34 | tpe.typeSymbol.tpe.toString 35 | case typeRef @ TypeRef(_, _, arg1 :: ret :: Nil) if definitions.isFunctionType(typeRef) => 36 | typeToString(tree, arg1) + " => " + typeToString(tree, ret) 37 | case MethodType(params, result) => 38 | val printedParams = params.map(s => typeToString(tree, s.tpe)).mkString(", ") 39 | val printedResult = typeToString(tree, result) 40 | 41 | if (params.isEmpty) { 42 | "() => " + printedResult 43 | } else if (params.size > 1) { 44 | "(" + printedParams + ") => " + printedResult 45 | } else { 46 | printedParams + " => " + printedResult 47 | } 48 | 49 | case tpe => 50 | tpe.toString 51 | } 52 | } 53 | 54 | def balanceBracketsInLayout(open: Char, close: Char, l: Layout) = { 55 | balanceBrackets(open, 
close)(Fragment(l.asText)).toLayout 56 | } 57 | 58 | def balanceBrackets(open: Char, close: Char)(f: Fragment) = Fragment { 59 | val (opening, closing) = SourceUtils.countRelevantBrackets(f.toLayout.asText, open, close) 60 | if (opening > closing) { 61 | f.asText + (("" + close) * (opening - closing)) 62 | } else if (opening < closing) { 63 | (("" + open) * (closing - opening)) + f.asText 64 | } else { 65 | f.asText 66 | } 67 | } 68 | 69 | /** 70 | * When extracting source code from the file via a tree's position, 71 | * it depends on the tree type whether we can use the position's 72 | * start or point. 73 | * 74 | * @param t The tree that will be replaced. 75 | * @param p The position to adapt. This does not have to be the position of t. 76 | */ 77 | def adjustedStartPosForSourceExtraction(t: Tree, p: Position): Position = t match { 78 | case _: Select | _: New if t.pos.isRange && t.pos.start > t.pos.point => 79 | p withStart (p.start min p.point) 80 | case _ => 81 | p 82 | } 83 | 84 | lazy val precedence: Name => Int = { 85 | 86 | // Copied from the compiler 87 | def newUnitParser(code: String) = new syntaxAnalyzer.UnitParser(newCompilationUnit(code)) 88 | def newCompilationUnit(code: String) = new CompilationUnit(newSourceFile(code)) 89 | def newSourceFile(code: String) = new BatchSourceFile("", code) 90 | 91 | val parser = newUnitParser("") 92 | 93 | // I ♥ Scala 94 | name => parser.precedence(newTermName(name.decode)) 95 | } 96 | 97 | } 98 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/sourcegen/Formatting.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring.sourcegen 6 | 7 | import scala.util.Properties 8 | 9 | /** 10 | * Holds default formatting preferences. 
11 | */ 12 | trait Formatting { 13 | 14 | /** 15 | * The characters that are used to indent changed code. 16 | */ 17 | def defaultIndentationStep = " " 18 | 19 | /** 20 | * The characters that surround an import with multiple 21 | * import selectors inside the braces: 22 | * 23 | * import a.{*name*} 24 | */ 25 | def spacingAroundMultipleImports = "" 26 | 27 | /** 28 | * If set to `true` printer of import should drop `scala.` prefix: 29 | * 30 | * `import scala.util.Try` should be printed as 31 | * 32 | * `import util.Try` 33 | */ 34 | def dropScalaPackage = false 35 | 36 | /** Used when new line is added to source file and EOL is needed. */ 37 | def lineDelimiter = Properties.lineSeparator 38 | } 39 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/sourcegen/Fragment.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package sourcegen 7 | 8 | trait Fragment { 9 | self => 10 | val leading: Layout 11 | val center: Layout 12 | val trailing: Layout 13 | 14 | val pre = NoRequisite: Requisite 15 | val post = NoRequisite: Requisite 16 | 17 | override def toString() = asText 18 | 19 | def dropLeadingLayout: Fragment = new Fragment { 20 | val leading = NoLayout 21 | val center = self.center 22 | val trailing = self.trailing 23 | override val pre = self.pre 24 | override val post = self.post 25 | } 26 | 27 | def dropLeadingIndentation: Fragment = new Fragment { 28 | val leading = Layout(self.leading.asText.replaceFirst("""^\s*""", "")) 29 | val center = self.center 30 | val trailing = self.trailing 31 | override val pre = self.pre 32 | override val post = self.post 33 | } 34 | 35 | def dropTrailingLayout: Fragment = new Fragment { 36 | val leading = self.leading 37 | val center = self.center 38 | val trailing = NoLayout 39 | override val pre = self.pre 40 | override val post 
= NoRequisite 41 | } 42 | 43 | def isEmpty: Boolean = this match { 44 | case EmptyFragment => true 45 | case _ if asText == "" => true 46 | case _ => false 47 | } 48 | 49 | def ifNotEmpty(f: Fragment => Fragment): Fragment = this match { 50 | case EmptyFragment => EmptyFragment 51 | case _ if asText == "" => EmptyFragment 52 | case _ => f(this) 53 | } 54 | 55 | def toLayout: Layout = new Layout { 56 | def asText = self.asText 57 | } 58 | 59 | def asText: String = pre(NoLayout, leading).asText + center.asText + post(trailing, NoLayout).asText 60 | 61 | /** 62 | * Combines two fragments, makes sure that 63 | * Requisites are satisfied. 64 | * 65 | * Combining two fragments (a,b,c) and (d,e,f) 66 | * yields a fragment (a,bcde,f). 67 | */ 68 | def ++ (o: Fragment): Fragment = o match { 69 | case EmptyFragment => this 70 | case _ => new Fragment { 71 | val leading = self.leading 72 | val center = self.center ++ (self.post ++ o.pre)(self.trailing, o.leading) ++ o.center 73 | val trailing = o.trailing 74 | 75 | override val pre = self.pre 76 | override val post = o.post 77 | } 78 | } 79 | 80 | /** 81 | * Combines a fragment with a layout, makes sure that 82 | * Requisites are satisfied. 83 | * 84 | * Combining (a,b,c) and (d) 85 | * yields a fragment (a,b,cd). 
86 | */ 87 | def ++ (o: Layout): Fragment = o match { 88 | case NoLayout => this 89 | case _ => new Fragment { 90 | val leading = self.leading 91 | val center = self.center 92 | val trailing = self.post(self.trailing, o) 93 | 94 | override val pre = self.pre 95 | override val post = /*if (self.post.isRequired(this.trailing, NoLayout)) self.post else*/ NoRequisite 96 | } 97 | } 98 | 99 | def ++ (after: Requisite, before: Requisite = NoRequisite): Fragment = { 100 | new Fragment { 101 | val leading = self.leading 102 | val center = self.center 103 | val trailing = self.trailing 104 | 105 | override val pre = before ++ self.pre 106 | override val post = self.post ++ after 107 | } 108 | } 109 | } 110 | 111 | abstract class EmptyFragment extends Fragment { 112 | val leading = NoLayout: Layout 113 | val center = NoLayout: Layout 114 | val trailing = NoLayout: Layout 115 | } 116 | 117 | object EmptyFragment extends EmptyFragment 118 | 119 | object Fragment { 120 | 121 | def unapply(f: Fragment) = Some((f.leading, f.center, f.trailing)) 122 | 123 | def apply(l: Layout, c: Layout, t: Layout) = new Fragment { 124 | val leading = l 125 | val center = c 126 | val trailing = t 127 | } 128 | 129 | def apply(s: String) = new EmptyFragment { 130 | override val center = Layout(s) 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/sourcegen/Indentation.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package sourcegen 7 | 8 | /** 9 | * A class that handles indentation and is passed between 10 | * the pretty printer and the source generator. 11 | * 12 | * defaultIncrement specifies how much the indentation should 13 | * be incremented for newly generated code (pretty printer). 
14 | */ 15 | trait Indentations { 16 | 17 | this: common.Tracing => 18 | 19 | class Indentation(val defaultIncrement: String, val current: String) { 20 | 21 | def incrementDefault = new Indentation(defaultIncrement, current + defaultIncrement) 22 | 23 | def setTo(i: String) = new Indentation(defaultIncrement, i) 24 | 25 | def needsToBeFixed(oldIndentation: String, surroundingLayout: Layout*) = { 26 | oldIndentation != current && surroundingLayout.exists(_.contains("\n")) 27 | } 28 | 29 | def fixIndentation(code: String, oldIndentation: String) = { 30 | trace("code is %s", code) 31 | trace("desired indentation is %s", current) 32 | trace("current indentation is %s", oldIndentation) 33 | Layout(code.replace("\n"+ oldIndentation, "\n"+ current)) 34 | } 35 | } 36 | 37 | def indentationString(tree: scala.tools.nsc.Global#Tree): String = { 38 | 39 | def stripCommentFromSourceFile() = { 40 | if(memoizedSourceWithoutComments contains tree.pos.source.path) { 41 | memoizedSourceWithoutComments(tree.pos.source.path) 42 | } else { 43 | val src = SourceUtils.stripComment(tree.pos.source.content) 44 | memoizedSourceWithoutComments += tree.pos.source.path → src 45 | src 46 | } 47 | } 48 | 49 | var i = { 50 | if(tree.pos.start == tree.pos.source.length || tree.pos.source.content(tree.pos.start) == '\n' || tree.pos.source.content(tree.pos.start) == '\r') 51 | tree.pos.start - 1 52 | else 53 | tree.pos.start 54 | } 55 | val contentWithoutComment = stripCommentFromSourceFile() 56 | 57 | while(i >= 0 && contentWithoutComment(i) != '\n') { 58 | i -= 1 59 | } 60 | 61 | i += 1 62 | 63 | """\s*""".r.findFirstIn(contentWithoutComment.slice(i, tree.pos.start).mkString).getOrElse("") 64 | } 65 | 66 | private [this] val memoizedSourceWithoutComments = scala.collection.mutable.Map.empty[String, String] 67 | 68 | } -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/sourcegen/Layout.scala: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package sourcegen 7 | 8 | import scala.reflect.internal.util.SourceFile 9 | 10 | trait Layout { 11 | self => 12 | 13 | def contains(s: String) = withoutComments.contains(s) 14 | 15 | def matches(r: String) = withoutComments.matches(r) 16 | 17 | /** 18 | * @return Returns this layout as a string but without comments. 19 | * Comments are replaced by whitespace. 20 | */ 21 | lazy val withoutComments = SourceUtils.stripComment(asText) 22 | 23 | def asText: String 24 | 25 | override def toString() = asText 26 | 27 | def ++ (o: Layout) = o match { 28 | case NoLayout => this 29 | case _ => new Layout { 30 | override def asText = self.asText + o.asText 31 | } 32 | } 33 | 34 | def ++ (o: Fragment): Fragment = new Fragment { 35 | val leading = o.pre(self, o.leading) 36 | val center = o.center 37 | val trailing = o.trailing 38 | 39 | override val pre = if (o.pre.isRequired(this.leading, NoLayout)) o.pre else NoRequisite 40 | override val post = o.post 41 | } 42 | 43 | def ++ (r: Requisite): Fragment = new EmptyFragment { 44 | override val trailing = self 45 | override val post = r 46 | } 47 | 48 | def isEmpty: Boolean = asText.isEmpty 49 | def nonEmpty: Boolean = !isEmpty 50 | } 51 | 52 | case object NoLayout extends Layout { 53 | val asText = "" 54 | } 55 | 56 | object Layout { 57 | 58 | case class LayoutFromFile(source: SourceFile, start: Int, end: Int) extends Layout { 59 | 60 | lazy val asText = source.content.slice(start, end).mkString 61 | 62 | def splitAfter(cs: Char*): (Layout, Layout) = splitFromLeft(cs) match { 63 | case None => this → NoLayout 64 | case Some(i) => copy(end = i+1) → copy(start = i+1) 65 | } 66 | 67 | def splitAfterLast(cs: Char*): (Layout, Layout) = splitFromRight(cs) match { 68 | case None => this → NoLayout 69 | case Some(i) => copy(end = i+1) → copy(start = i+1) 70 | } 71 | 
72 | def splitAtAndExclude(cs: Char*): (Layout, Layout) = splitFromLeft(cs) match { 73 | case None => this → NoLayout 74 | case Some(i) => copy(end = i) → copy(start = i+1) 75 | } 76 | 77 | def splitBefore(cs: Char*): (Layout, Layout) = splitFromLeft(cs) match { 78 | case None => NoLayout → this 79 | case Some(i) => copy(end = i) → copy(start = i) 80 | } 81 | 82 | private def splitFromLeft(cs: Seq[Char]): Option[Int] = { 83 | split(cs, c => withoutComments.indexOf(c.toInt)) 84 | } 85 | 86 | private def splitFromRight(cs: Seq[Char]): Option[Int] = { 87 | split(cs, c => withoutComments.lastIndexOf(c.toInt)) 88 | } 89 | 90 | private def split(cs: Seq[Char], findIndex: Char => Int): Option[Int] = cs.toList match { 91 | case Nil => 92 | None 93 | case x :: xs => 94 | val i = findIndex(x) 95 | if(i >= 0 ) { 96 | Some(start + i) 97 | } else 98 | split(xs, findIndex) 99 | } 100 | } 101 | 102 | def apply(source: SourceFile, start: Int, end: Int) = LayoutFromFile(source, start, end) 103 | 104 | def apply(s: String) = new Layout { 105 | val asText = s 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/sourcegen/Requisite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package sourcegen 7 | 8 | trait Requisite { 9 | self => 10 | 11 | def isRequired(l: Layout, r: Layout): Boolean 12 | 13 | def apply(l: Layout, r: Layout): Layout = { 14 | if(isRequired(l, r)) { 15 | insertBetween(l, r) 16 | } else { 17 | l ++ r 18 | } 19 | } 20 | 21 | protected def insertBetween(l: Layout, r: Layout) = l ++ getLayout ++ r 22 | 23 | def getLayout: Layout 24 | 25 | def ++(other: Requisite): Requisite = (self, other) match { 26 | case (r, NoRequisite) => r 27 | case (NoRequisite, r) => r 28 | case _ => new Requisite { 29 | def isRequired(l: Layout, r: Layout) =
self.isRequired(l, r) || other.isRequired(l, r) 30 | def getLayout = self.getLayout ++ other.getLayout 31 | override def apply(l: Layout, r: Layout) = { 32 | val _1 = if(self.isRequired(l, r)) self.getLayout else NoLayout 33 | val _2 = if(other.isRequired(l, r)) other.getLayout else NoLayout 34 | l ++ _1 ++ _2 ++ r 35 | } 36 | } 37 | } 38 | } 39 | 40 | object Requisite { 41 | 42 | def allowSurroundingWhitespace(req: String, toPrint: String): Requisite = { 43 | 44 | // Quote the requisite so that every regex metacharacter in it (not only brackets 45 | // and parentheses, but also e.g. `.`, `+`, `*`, `?` or `|`) is matched literally. 46 | val regexSafeString = java.util.regex.Pattern.quote(req) 47 | 48 | new Requisite { 49 | def isRequired(l: Layout, r: Layout) = { 50 | val isInLeft = l.matches("(?ms).*\\s*"+ regexSafeString +"\\s*$") 51 | val isInRight = r.matches("(?ms)^\\s*"+ regexSafeString + ".*") 52 | !isInLeft && !isInRight 53 | } 54 | def getLayout = Layout(toPrint) 55 | } 56 | } 57 | 58 | def allowSurroundingWhitespace(str: String): Requisite = { 59 | allowSurroundingWhitespace(str, str) 60 | } 61 | 62 | def anywhere(s: String): Requisite = anywhere(s, s) 63 | 64 | def anywhere(req: String, print: String): Requisite = new Requisite { 65 | def isRequired(l: Layout, r: Layout) = { 66 | !(l.contains(req) || r.contains(req)) 67 | } 68 | def getLayout = Layout(print) 69 | } 70 | 71 | val Blank = new Requisite { 72 | def isRequired(l: Layout, r: Layout) = { 73 | val _1 = l.matches("(?s).*\\s+$") 74 | val _2 = r.matches("(?s)^\\s+.*") 75 | 76 | !(_1 || _2) 77 | } 78 | val getLayout = Layout(" ") 79 | } 80 | 81 | def newline(indentation: String, nl: String, force: Boolean = false) = new Requisite { 82 | def isRequired(l: Layout, r: Layout) = { 83 | val _1 = l.matches("(?ms).*\r?\n\\s*$") 84 | val _2 = r.matches("(?ms)^\\s*\r?\n.*") 85 | !(_1 || _2) 86 | } 87 | def getLayout = Layout(nl+ indentation) 88 | override def insertBetween(l: Layout, r: Layout) = { 89 | if(!force &&
r.asText.startsWith(indentation)) { 90 | l ++ Layout(nl) ++ r 91 | } else { 92 | l ++ getLayout ++ r 93 | } 94 | } 95 | } 96 | } 97 | 98 | object NoRequisite extends Requisite { 99 | def isRequired(l: Layout, r: Layout) = false 100 | val getLayout = NoLayout 101 | } -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/transformation/TransformableSelections.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.transformation 2 | 3 | import scala.tools.refactoring.common.Selections 4 | import scala.tools.refactoring.common.CompilerAccess 5 | 6 | trait TransformableSelections extends Selections with TreeTransformations { 7 | self: CompilerAccess => 8 | 9 | import global._ 10 | 11 | implicit class TransformableSelection(selection: Selection) { 12 | def descendToEnclosingTreeAndThen(trans: Transformation[Tree, Tree]) = 13 | topdown { 14 | matchingChildren { 15 | predicate { (t: Tree) => 16 | t.samePosAndType(selection.enclosingTree) 17 | } &> trans 18 | } 19 | } 20 | 21 | private def replaceSingleStatementBy(replacement: Tree) = { 22 | val original = selection.selectedTopLevelTrees.head 23 | transform { 24 | case t if t.samePosAndType(original) => 25 | replacement replaces t 26 | case t => 27 | t 28 | } 29 | } 30 | 31 | private def replaceSequenceBy(replacement: Tree, preserveHierarchy: Boolean) = { 32 | transform { 33 | case block @ Block(stats, expr) => 34 | val allStats = (stats :+ expr) 35 | if (allStats.length == selection.selectedTopLevelTrees.length && !preserveHierarchy) { 36 | // only replace whole block if allowed to modify tree hierarchy 37 | replacement replaces block 38 | } else { 39 | val newStats = allStats.replaceSequencePreservingPositions(selection.selectedTopLevelTrees, replacement :: Nil) 40 | mkBlock(newStats) replaces block 41 | } 42 | } 43 | } 44 | 45 | /** 46 | * Replaces the selection by `replacement`.
47 | * 48 | * @param replacement the tree that takes the place of the selected trees 49 | * @param preserveHierarchy whether the original tree hierarchy must be preserved or 50 | * could be reduced if possible. 51 | * E.g. a selection contains all trees of the enclosing block: 52 | * - with `preserveHierarchy = false` the block will be replaced by `replacement` 53 | * - with `preserveHierarchy = true` the block will remain with `replacement` 54 | * as its only child tree 55 | */ 56 | def replaceBy(replacement: Tree, preserveHierarchy: Boolean = false) = { 57 | descendToEnclosingTreeAndThen { 58 | if (selection.selectedTopLevelTrees.length == 1) 59 | replaceSingleStatementBy(replacement) 60 | else 61 | replaceSequenceBy(replacement, preserveHierarchy) 62 | } 63 | } 64 | } 65 | } -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/util/Memoized.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package util 7 | 8 | object Memoized { 9 | 10 | // Use this switch to temporarily turn off memoization for 11 | // debugging purposes: 12 | private val MemoizationEnabled = true 13 | 14 | /** 15 | * 16 | * Create a function that memoizes its results in a WeakHashMap. 17 | * 18 | * Note that memoization is tricky if the objects that 19 | * are memoized are mutable and the function being memoized 20 | * returns a value that somehow depends on this mutable property. 21 | * 22 | * For example, if we memoize a function that filters trees 23 | * based on their position, and later modify the position of 24 | * a tree, the memoized function will return the wrong value. 25 | * 26 | * So in order to make it safe, the mkKey function can be used 27 | * to provide a better key by including the mutable value. 28 | * 29 | * @param mkKey A function that creates the key. 30 | * @param toMem The function we want to memoize.
31 | */ 32 | def on[X, Y, Z](mkKey: X => Y)(toMem: X => Z): X => Z = { 33 | if (!MemoizationEnabled) { 34 | toMem 35 | } else { 36 | val cache = new java.util.WeakHashMap[Y, Z] 37 | 38 | (x: X) => { 39 | val k = mkKey(x) 40 | if(cache.containsKey(k)) { 41 | val n = cache.get(k) 42 | if(n == null) { 43 | toMem(x) 44 | } else { 45 | n 46 | } 47 | } else { 48 | val n = toMem(x) 49 | cache.put(k, n) 50 | n 51 | } 52 | } 53 | } 54 | } 55 | /** Like `on`, but uses the argument itself as the cache key. */ 56 | def apply[X, Z](toMem: X => Z): X => Z = { 57 | if (!MemoizationEnabled) { 58 | toMem 59 | } else { 60 | val cache = new java.util.WeakHashMap[X, Z] 61 | 62 | (x: X) => { 63 | if(cache.containsKey(x)) { 64 | val n = cache.get(x) 65 | if(n == null) { 66 | toMem(x) 67 | } else { 68 | n 69 | } 70 | } else { 71 | val n = toMem(x) 72 | cache.put(x, n) 73 | n 74 | } 75 | } 76 | } 77 | } 78 | } 79 | 80 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/util/SourceHelpers.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.util 2 | 3 | import scala.annotation.tailrec 4 | import scala.math.min 5 | import scala.reflect.api.Position 6 | import scala.reflect.internal.util.SourceFile 7 | import scala.reflect.internal.util.RangePosition 8 | 9 | object SourceHelpers { 10 | 11 | /** 12 | * Decides whether a selection lies within a given text 13 | * 14 | * This is best explained by a few examples (selections are indicated by `[]`): 15 | * 22 | */ 23 | def isRangeWithin(text: String, selection: SourceWithSelection): Boolean = { 24 | if (selection.length > text.length || selection.source.length < text.length) { 25 | false 26 | } else { 27 | val maxStepsBack = min(text.length - selection.length, selection.start) 28 | 29 | @tailrec 30 | def tryMatchText(stepsBack: Int = 0): Boolean = { 31 | if (stepsBack > maxStepsBack) { 32 | false 33 | } else { 34 | val start = selection.start - stepsBack 35 |
36 | @tailrec 37 | def matchSlice(i: Int = 0): Boolean = { 38 | if (i >= text.length) { 39 | true 40 | } else { 41 | if (text.charAt(i) != selection.source.charAt(start + i)) { 42 | false 43 | } else { 44 | matchSlice(i + 1) 45 | } 46 | } 47 | } 48 | 49 | if (start + text.length <= selection.source.length && matchSlice()) { 50 | true 51 | } else { 52 | tryMatchText(stepsBack + 1) 53 | } 54 | } 55 | } 56 | 57 | tryMatchText() 58 | } 59 | } 60 | 61 | def stringCoveredBy(pos: Position): Option[String] = { 62 | if (pos.isRange) Some(new String(pos.source.content.slice(pos.start, pos.end))) 63 | else None 64 | } 65 | 66 | def findComments(source: SourceFile, includeTrailingNewline: Boolean = true): List[RangePosition] = { 67 | import scala.tools.refactoring.util.SourceWithMarker 68 | import scala.tools.refactoring.util.SourceWithMarker.Movements 69 | import scala.tools.refactoring.util.SourceWithMarker.Movements.charToMovement 70 | 71 | val commentMvnt = { 72 | if (includeTrailingNewline) Movements.comment ~ '\r'.optional ~ '\n'.optional 73 | else Movements.comment 74 | } 75 | 76 | @tailrec 77 | def doWork(srcWithMarker: SourceWithMarker, acc: List[RangePosition] = Nil): List[RangePosition] = { 78 | if (srcWithMarker.isDepleted) { 79 | acc 80 | } else { 81 | srcWithMarker.applyMovement(commentMvnt) match { 82 | case Some(srcAfterComment) => 83 | val (commentStart, commentEnd) = (srcWithMarker.marker, srcAfterComment.marker) 84 | val commentRange = new RangePosition(source, commentStart, commentStart, commentEnd) 85 | doWork(srcAfterComment, commentRange :: acc) 86 | 87 | case None => 88 | doWork(srcWithMarker.stepForward, acc) 89 | } 90 | } 91 | } 92 | 93 | doWork(SourceWithMarker(source.content)).reverse 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/util/SourceWithSelection.scala: -------------------------------------------------------------------------------- 1 | package
scala.tools.refactoring.util 2 | 3 | case class SourceWithSelection(source: IndexedSeq[Char], start: Int, end: Int) { 4 | require(start > -1 && end >= start) 5 | require(end <= source.length) 6 | 7 | def length = end - start 8 | } 9 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/util/UnionFind.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.util 2 | 3 | import scala.collection.mutable.Map 4 | import scala.collection.mutable.HashMap 5 | 6 | /* 7 | * Implements a standard Union-Find (a.k.a Disjoint Set) data 8 | * structure with permissive behavior with respect to 9 | * non-existing elements in the structure (Unknown elements are 10 | * added as new elements when queried for). 11 | * 12 | * See Cormen, Thomas H.; Leiserson, Charles E.; Rivest, Ronald 13 | * L.; Stein, Clifford (2001), "Chapter 21: Data structures for 14 | * Disjoint Sets", Introduction to Algorithms (Second ed.), MIT 15 | * Press, pp. 498–524, ISBN 0-262-03293-7 16 | * 17 | * Amortized time for a sequence of m {union, find} operations 18 | * is O(m * InvAckermann(n)) where n is the number of elements 19 | * and InvAckermann is the inverse of the Ackermann function. 20 | * 21 | * Not thread-safe. 22 | */ 23 | 24 | class UnionFind[T]() { 25 | 26 | private val parent: Map[T, T] = new HashMap[T,T] { 27 | override def default(s: T) = { 28 | get(s) match { 29 | case Some(v) => v 30 | case None => put(s, s); s 31 | } 32 | } 33 | } 34 | 35 | private val rank: Map[T, Int] = new HashMap[T,Int] { 36 | override def default(s: T) = { 37 | get(s) match { 38 | case Some(v) => v 39 | case None => put(s, 1); 1 40 | } 41 | } 42 | } 43 | 44 | /** 45 | * Return the representative (root) of the equivalence class of `s`. 46 | * Uses path compression.
47 | */ 48 | def find(s: T): T = { 49 | val ps = parent(s) 50 | if (ps == s) s else { 51 | val cs = find(ps) 52 | parent(s) = cs 53 | cs 54 | } 55 | } 56 | 57 | /** 58 | * Unify the equivalence classes of `x` and `y`. 59 | * Uses union by rank: the root with the lower rank is attached below the root with the higher rank. 60 | */ 61 | def union(x: T, y: T): Unit = { 62 | val cx = find(x) 63 | val cy = find(y) 64 | if (cx != cy) { 65 | val rx = rank(cx) // ranks are only maintained for roots, so compare the roots' ranks 66 | val ry = rank(cy) 67 | if (rx > ry) parent(cy) = cx 68 | else if (rx < ry) parent(cx) = cy 69 | else { 70 | rank(cx) += 1 71 | parent(cy) = cx 72 | } 73 | } 74 | } 75 | 76 | /** 77 | * Enumerates the equivalence class of element x, i.e. all known elements that share its root. 78 | */ 79 | def equivalenceClass(x: T): List[T] = { 80 | val root = find(x) 81 | parent.keys.toList filter (k => find(k) == root) // keys are copied first because find() updates `parent` (path compression) 82 | } 83 | 84 | } 85 | -------------------------------------------------------------------------------- /src/main/scala/scala/tools/refactoring/util/UniqueNames.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.util 2 | 3 | import java.util.UUID 4 | 5 | object UniqueNames { 6 | def scalaFile(): String = { 7 | s"${basename()}.scala" 8 | } 9 | 10 | def srcDir(): String = { 11 | s"src-${uid()}" 12 | } 13 | 14 | def basename(): String = { 15 | uid() 16 | } 17 | 18 | def scalaPackage(): String = { 19 | uid() 20 | } 21 | 22 | private def uid(): String = { 23 | def longToName(l: Long) = { 24 | java.lang.Long.toString(l, Character.MAX_RADIX).replace("-", "_") 25 | } 26 | 27 | val rid = UUID.randomUUID() 28 | "uid" + longToName(rid.getLeastSignificantBits) + longToName(rid.getMostSignificantBits) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/test/java/scala/tools/refactoring/common/TracingHelpersTest.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.common 2 | 3 | import org.junit.Test 4 | import org.junit.Assert._ 5 | 6 | 7 | class TracingHelpersTest
{ 8 | import TracingHelpers._ 9 | 10 | @Test 11 | def compactifyWithShortMsgs(): Unit = { 12 | assertEquals("", compactify("")) 13 | assertEquals("xxx", compactify("xxx")) 14 | } 15 | 16 | @Test 17 | def compactifyWithMultilineMsg(): Unit = { 18 | val mlMsg = "1. Do\n2. Make\n3. Say\n4. Think" 19 | assertEquals("1. Do...(3 more lines ommitted)", compactify(mlMsg)) 20 | } 21 | 22 | @Test 23 | def compactifyWithVeryLongMsg(): Unit = { 24 | val looooongMsg = "Loooooooonger, then eeeeeeeeeeever before!!!!!!!!!!!!!!!!!" 25 | val compactified = compactify(looooongMsg) 26 | assertTrue(compactified.length < looooongMsg.length) 27 | assertTrue(compactified.startsWith("Loooooooonger")) 28 | } 29 | } -------------------------------------------------------------------------------- /src/test/java/scala/tools/refactoring/tests/util/ExceptionWrapper.java: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.tests.util; 2 | 3 | import org.junit.rules.TestRule; 4 | import org.junit.runner.Description; 5 | import org.junit.runners.model.Statement; 6 | 7 | import scala.tools.nsc.util.FailedInterrupt; 8 | 9 | /** 10 | * In case an assertion error is caught and wrapped by another error type (as it 11 | * is the case for exceptions that are thrown on the compiler thread), we need 12 | * to manually unwrap them later, in order to let JUnit "see" them.
13 | */ 14 | public final class ExceptionWrapper implements TestRule { 15 | 16 | @Override 17 | public Statement apply(final Statement base, final Description description) { 18 | return new Statement() { 19 | @Override 20 | public void evaluate() throws Throwable { 21 | try { 22 | base.evaluate(); 23 | } catch (FailedInterrupt e) { 24 | // a FailedInterrupt is thrown when an exception occurs on the compiler thread; rethrow its cause (or the interrupt itself if it carries no cause) 25 | throw e.getCause() != null ? e.getCause() : e; 26 | } 27 | } 28 | }; 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/test/java/scala/tools/refactoring/tests/util/ScalaVersion.java: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.tests.util; 2 | 3 | import java.lang.annotation.Retention; 4 | import java.lang.annotation.RetentionPolicy; 5 | 6 | @Retention(RetentionPolicy.RUNTIME) 7 | public @interface ScalaVersion { 8 | String matches() default ""; 9 | String doesNotMatch() default ""; 10 | } 11 | -------------------------------------------------------------------------------- /src/test/java/scala/tools/refactoring/tests/util/ScalaVersionTestRule.java: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.tests.util; 2 | 3 | import org.junit.Assume; 4 | import org.junit.rules.MethodRule; 5 | import org.junit.runners.model.FrameworkMethod; 6 | import org.junit.runners.model.Statement; 7 | 8 | import scala.util.Properties; 9 | 10 | public class ScalaVersionTestRule implements MethodRule { 11 | 12 | final class EmptyStatement extends Statement { 13 | @Override 14 | public void evaluate() throws Throwable { 15 | Assume.assumeTrue(false); 16 | } 17 | } 18 | 19 | public Statement apply(Statement stmt, FrameworkMethod meth, Object arg2) { 20 | ScalaVersion onlyOn = meth.getAnnotation(ScalaVersion.class); 21 | String versionString = Properties.versionString(); 22 | 23 | if (onlyOn != null) { 24 | if
(!onlyOn.doesNotMatch().isEmpty() && versionString.contains(onlyOn.doesNotMatch())) { 25 | return new EmptyStatement(); 26 | } else if (versionString.contains(onlyOn.matches())) { 27 | return stmt; 28 | } else { 29 | return new EmptyStatement(); 30 | } 31 | } else { 32 | return stmt; 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/test/java/scala/tools/refactoring/tests/util/TestRules.java: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.tests.util; 2 | 3 | import org.junit.Rule; 4 | 5 | public abstract class TestRules { 6 | 7 | // all rules need to be public fields 8 | 9 | @Rule 10 | public final ScalaVersionTestRule rule1 = new ScalaVersionTestRule(); 11 | 12 | @Rule 13 | public final ExceptionWrapper rule2 = new ExceptionWrapper(); 14 | } 15 | -------------------------------------------------------------------------------- /src/test/scala-2.10/README: -------------------------------------------------------------------------------- 1 | Put tests specifically for Scala-2.10.x here -------------------------------------------------------------------------------- /src/test/scala-2.10/scala/tools/refactoring/tests/implementations/imports/OrganizeImportsScalaSpecificTests.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.tests.implementations.imports 2 | 3 | import scala.tools.refactoring.implementations.OrganizeImports 4 | import scala.tools.refactoring.implementations.OrganizeImports.Dependencies 5 | 6 | class OrganizeImportsScalaSpecificTests extends OrganizeImportsBaseTest { 7 | private def organizeCustomized( 8 | groupPkgs: List[String] = List("java", "scala", "org", "com"), 9 | useWildcards: Set[String] = Set("scalaz", "scalaz.Scalaz"), 10 | dependencies: Dependencies.Value, 11 | organizeLocalImports: Boolean = true)(pro: FileSet) = new
OrganizeImportsRefatoring(pro) { 12 | val oiConfig = OrganizeImports.OrganizeImportsConfig( 13 | importsStrategy = Some(OrganizeImports.ImportsStrategy.ExpandImports), 14 | wildcards = useWildcards, 15 | groups = groupPkgs) 16 | val params = { 17 | new refactoring.RefactoringParameters( 18 | deps = dependencies, 19 | organizeLocalImports = organizeLocalImports, 20 | config = Some(oiConfig)) 21 | } 22 | }.mkChanges 23 | 24 | @Ignore("Passes for scala 2.11 only. Implementation bases on 2.11 specific tree structure.") 25 | @Test 26 | def shouldNotRemoveImportWhenExtendedClassHasInferredTypeParam() = new FileSet { 27 | """ 28 | package acme 29 | 30 | class Extended[T](val a: T, t: String) 31 | """ isNotModified 32 | 33 | """ 34 | /*<-*/ 35 | package acme.test 36 | 37 | import acme.Extended 38 | 39 | class Tested(val id: Int) extends Extended(id, "text") 40 | """ isNotModified 41 | } applyRefactoring organizeCustomized(dependencies = Dependencies.RecomputeAndModify) 42 | } 43 | -------------------------------------------------------------------------------- /src/test/scala-2.10/scala/tools/refactoring/tests/implementations/imports/OrganizeImportsWithMacrosTest.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.tests.implementations.imports 2 | 3 | /** Refactoring for 2.10 does not support macros. Left unimplemented.
*/ 4 | class OrganizeImportsWithMacrosTest extends OrganizeImportsBaseTest 5 | -------------------------------------------------------------------------------- /src/test/scala-2.11/README: -------------------------------------------------------------------------------- 1 | Put tests specifically for Scala-2.11.x here -------------------------------------------------------------------------------- /src/test/scala-2.11/scala/tools/refactoring/tests/implementations/imports/OrganizeImportsScalaSpecificTests.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.tests.implementations.imports 2 | 3 | import scala.tools.refactoring.implementations.OrganizeImports 4 | import scala.tools.refactoring.implementations.OrganizeImports.Dependencies 5 | 6 | class OrganizeImportsScalaSpecificTests extends OrganizeImportsBaseTest { 7 | private def organizeCustomized( 8 | groupPkgs: List[String] = List("java", "scala", "org", "com"), 9 | useWildcards: Set[String] = Set("scalaz", "scalaz.Scalaz"), 10 | dependencies: Dependencies.Value, 11 | organizeLocalImports: Boolean = true)(pro: FileSet) = new OrganizeImportsRefatoring(pro) { 12 | val oiConfig = OrganizeImports.OrganizeImportsConfig( 13 | importsStrategy = Some(OrganizeImports.ImportsStrategy.ExpandImports), 14 | wildcards = useWildcards, 15 | groups = groupPkgs) 16 | val params = { 17 | new refactoring.RefactoringParameters( 18 | deps = dependencies, 19 | organizeLocalImports = organizeLocalImports, 20 | config = Some(oiConfig)) 21 | } 22 | }.mkChanges 23 | 24 | @Test 25 | def shouldNotRemoveImportWhenExtendedClassHasInferredTypeParam() = new FileSet { 26 | """ 27 | package acme 28 | 29 | class Extended[T](val a: T, t: String) 30 | """ isNotModified 31 | 32 | """ 33 | /*<-*/ 34 | package acme.test 35 | 36 | import acme.Extended 37 | 38 | class Tested(val id: Int) extends Extended(id, "text") 39 | """ isNotModified 40 | } applyRefactoring
organizeCustomized(dependencies = Dependencies.RecomputeAndModify) 41 | 42 | @Test 43 | def shouldNotRemoveImportWhenJustPackage_v3() = new FileSet { 44 | """ 45 | /*<-*/ 46 | package tested 47 | 48 | import scala.reflect.macros.whitebox 49 | 50 | class Tested(val c: whitebox.Context) 51 | """ isNotModified 52 | } applyRefactoring organizeCustomized(dependencies = Dependencies.RecomputeAndModify) 53 | } 54 | -------------------------------------------------------------------------------- /src/test/scala-2.12/README: -------------------------------------------------------------------------------- 1 | Put tests specifically for Scala-2.12.x here -------------------------------------------------------------------------------- /src/test/scala-2.12/scala/tools/refactoring/tests/implementations/imports/OrganizeImportsScalaSpecificTests.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.tests.implementations.imports 2 | 3 | import scala.tools.refactoring.implementations.OrganizeImports 4 | import scala.tools.refactoring.implementations.OrganizeImports.Dependencies 5 | 6 | class OrganizeImportsScalaSpecificTests extends OrganizeImportsBaseTest { 7 | private def organizeCustomized( 8 | groupPkgs: List[String] = List("java", "scala", "org", "com"), 9 | useWildcards: Set[String] = Set("scalaz", "scalaz.Scalaz"), 10 | dependencies: Dependencies.Value, 11 | organizeLocalImports: Boolean = true)(pro: FileSet) = new OrganizeImportsRefatoring(pro) { 12 | val oiConfig = OrganizeImports.OrganizeImportsConfig( 13 | importsStrategy = Some(OrganizeImports.ImportsStrategy.ExpandImports), 14 | wildcards = useWildcards, 15 | groups = groupPkgs) 16 | val params = { 17 | new refactoring.RefactoringParameters( 18 | deps = dependencies, 19 | organizeLocalImports = organizeLocalImports, 20 | config = Some(oiConfig)) 21 | } 22 | }.mkChanges 23 | 24 | @Test 25 | def
shouldNotRemoveImportWhenExtendedClassHasInferredTypeParam() = new FileSet { 26 | """ 27 | package acme 28 | 29 | class Extended[T](val a: T, t: String) 30 | """ isNotModified 31 | 32 | """ 33 | /*<-*/ 34 | package acme.test 35 | 36 | import acme.Extended 37 | 38 | class Tested(val id: Int) extends Extended(id, "text") 39 | """ isNotModified 40 | } applyRefactoring organizeCustomized(dependencies = Dependencies.RecomputeAndModify) 41 | 42 | @Test 43 | def shouldNotRemoveImportWhenJustPackage_v3() = new FileSet { 44 | """ 45 | /*<-*/ 46 | package tested 47 | 48 | import scala.reflect.macros.whitebox 49 | 50 | class Tested(val c: whitebox.Context) 51 | """ isNotModified 52 | } applyRefactoring organizeCustomized(dependencies = Dependencies.RecomputeAndModify) 53 | } 54 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/implementations/OrganizeImportsAlgosTest.scala: -------------------------------------------------------------------------------- 1 | 2 | 3 | package scala.tools.refactoring.implementations 4 | 5 | import org.junit.Test 6 | import org.junit.Assert._ 7 | 8 | class OrganizeImportsAlgosTest { 9 | import OrganizeImports.Algos 10 | 11 | @Test 12 | def testGroupImportsWithTrivialExamples(): Unit = { 13 | testGroupImports(Nil, Nil, Nil) 14 | testGroupImports(List("org", "com"), Nil, Nil) 15 | testGroupImports(Nil, List("com.github.Lausbub"), List(List("com.github.Lausbub"))) 16 | } 17 | 18 | @Test 19 | def testGroupImportsWithSimpleExamples(): Unit = { 20 | testGroupImports( 21 | groups = List("org"), 22 | imports = List("org.junit.Test", "org.junit.Assert._", "language.postfixOps"), 23 | expected = List(List("org.junit.Test", "org.junit.Assert._"), List("language.postfixOps"))) 24 | 25 | testGroupImports( 26 | groups = List("java", "scala", "org"), 27 | imports = List("java.lang.String", "java.lang.Long", "org.junit.Before", "scala.Option.option2Iterable", "language.implicitConversions"),
28 | expected = List(List("java.lang.String", "java.lang.Long"), List("scala.Option.option2Iterable"), List("org.junit.Before"), List("language.implicitConversions"))) 29 | } 30 | 31 | @Test 32 | def testGroupImportsWithNastyExamples(): Unit = { 33 | testGroupImports( 34 | groups = List("a.c", "a"), 35 | imports = List("a.b.X", "a.c.Y"), 36 | expected = List(List("a.c.Y"), List("a.b.X"))) 37 | 38 | testGroupImports( 39 | groups = List("a", "a.c"), 40 | imports = List("a.b.X", "a.c.Y"), 41 | expected = List(List("a.b.X"), List("a.c.Y"))) 42 | 43 | testGroupImports( 44 | groups = List("a", "a.b", "ab"), 45 | imports = List("a.A", "a.b.AB", "ab.Ab1", "ab.Ab2", "abc.Abc1", "abc.Abc2"), 46 | expected = List(List("a.A"), List("a.b.AB"), List("ab.Ab1", "ab.Ab2"), List("abc.Abc1", "abc.Abc2"))) 47 | } 48 | 49 | @Test 50 | def testGroupImportsWithAlternatives(): Unit = { 51 | testGroupImports( 52 | groups = List("a.c", "a,b.c", "b"), 53 | imports = List("a.b.X", "a.c.Y", "b.a.Y", "b.c.X"), 54 | expected = List(List("a.c.Y"), List("a.b.X", "b.c.X"), List("b.a.Y"))) 55 | 56 | testGroupImports( 57 | groups = List("a.c", "nonexisting", "a,b.c", "b"), 58 | imports = List("a.b.X", "a.c.Y", "b.a.Y", "b.c.X"), 59 | expected = List(List("a.c.Y"), List("a.b.X", "b.c.X"), List("b.a.Y"))) 60 | 61 | testGroupImports( 62 | groups = List("a.c", "*", "a,b.c", "b"), 63 | imports = List("a.b.X", "a.c.Y", "b.a.Y", "b.c.X", "d.a.Y"), 64 | expected = List(List("a.c.Y"), List("d.a.Y"), List("a.b.X", "b.c.X"), List("b.a.Y"))) 65 | 66 | // default group at the end 67 | testGroupImports( 68 | groups = List("a.c", "a,b.c", "b"), 69 | imports = List("a.b.X", "a.c.Y", "b.a.Y", "b.c.X", "d.a.Y"), 70 | expected = List(List("a.c.Y"), List("a.b.X", "b.c.X"), List("b.a.Y"), List("d.a.Y"))) 71 | 72 | // disordered alternative 73 | testGroupImports( 74 | groups = List("a.c", "b.c,a", "b"), 75 | imports = List("a.b.X", "a.c.Y", "b.a.Y", "b.c.X"), 76 | expected = List(List("a.c.Y"), List("a.b.X", "b.c.X"),
List("b.a.Y"))) 77 | 78 | // disordered and repeated alternative 79 | testGroupImports( 80 | groups = List("a.c", "b.c,a", "b", "a,b.c"), 81 | imports = List("a.b.X", "a.c.Y", "b.a.Y", "b.c.X"), 82 | expected = List(List("a.c.Y"), List("a.b.X", "b.c.X"), List("b.a.Y"))) 83 | } 84 | 85 | private def testGroupImports(groups: List[String], imports: List[String], expected: List[List[String]]): Unit = { 86 | def getImportExpr(imp: String) = { 87 | val lastDot = imp.lastIndexOf('.') 88 | assert(lastDot >= 0) 89 | imp.substring(0, lastDot) 90 | } 91 | 92 | val actual = Algos.groupImports(getImportExpr)(groups, imports) 93 | assertEquals(expected, actual) 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/RefactoringTestSuite.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | package tests 3 | 4 | import org.junit.runner.RunWith 5 | import org.junit.runners.Suite 6 | import analysis._ 7 | import common._ 8 | import implementations._ 9 | import implementations.imports._ 10 | import sourcegen._ 11 | import transformation._ 12 | import util._ 13 | import scala.tools.refactoring.implementations.OrganizeImportsAlgosTest 14 | import scala.tools.refactoring.common.TracingHelpersTest 15 | 16 | @RunWith(value = classOf[Suite]) 17 | @Suite.SuiteClasses(value = Array( 18 | classOf[AddFieldTest], 19 | classOf[AddImportStatementTest], 20 | classOf[AddMethodTest], 21 | classOf[ChangeParamOrderTest], 22 | classOf[CompilationUnitDependenciesTest], 23 | classOf[CustomFormattingTest], 24 | classOf[DeclarationIndexTest], 25 | classOf[ExpandCaseClassBindingTest], 26 | classOf[ExplicitGettersSettersTest], 27 | classOf[extraction.ExtractCodeTest], 28 | classOf[extraction.ExtractExtractorTest], 29 | classOf[extraction.ExtractionsTest], 30 | classOf[extraction.ExtractMethodTest], 31 | classOf[extraction.ExtractParameterTest],
32 | classOf[extraction.ExtractValueTest], 33 | classOf[ExtractLocalTest], 34 | classOf[ExtractMethodTest], 35 | classOf[ExtractTraitTest], 36 | classOf[FindShadowedTest], 37 | classOf[GenerateHashcodeAndEqualsTest], 38 | classOf[ImportAnalysisTest], 39 | classOf[IndividualSourceGenTest], 40 | classOf[InlineLocalTest], 41 | classOf[InsertionPositionsTest], 42 | classOf[IntroduceProductNTraitTest], 43 | classOf[LayoutTest], 44 | classOf[MarkOccurrencesTest], 45 | classOf[MergeParameterListsTest], 46 | classOf[MoveClassTest], 47 | classOf[MoveConstructorToCompanionObjectTest], 48 | classOf[MultipleFilesIndexTest], 49 | classOf[NameValidationTest], 50 | classOf[OrganizeImportsCollapseSelectorsToWildcardTest], 51 | classOf[OrganizeImportsFullyRecomputeTest], 52 | classOf[OrganizeImportsGroupsTest], 53 | classOf[OrganizeImportsRecomputeAndModifyTest], 54 | classOf[OrganizeImportsTest], 55 | classOf[OrganizeImportsWildcardsTest], 56 | classOf[OrganizeMissingImportsTest], 57 | classOf[EnrichedTreesTest], 58 | classOf[PrependOrDropScalaPackageFromRecomputedTest], 59 | classOf[PrependOrDropScalaPackageKeepTest], 60 | classOf[PrettyPrinterTest], 61 | classOf[RenameTest], 62 | classOf[ReusingPrinterTest], 63 | classOf[ScopeAnalysisTest], 64 | classOf[SelectionDependenciesTest], 65 | classOf[SelectionExpansionsTest], 66 | classOf[SelectionPropertiesTest], 67 | classOf[SelectionsTest], 68 | classOf[SourceGenTest], 69 | classOf[SourceHelperTest], 70 | classOf[SplitParameterListsTest], 71 | classOf[TransformableSelectionTest], 72 | classOf[TreeAnalysisTest], 73 | classOf[TreeChangesDiscovererTest], 74 | classOf[TreeTransformationsTest], 75 | classOf[UnionFindInitTest], 76 | classOf[UnionFindTest], 77 | classOf[UnusedImportsFinderTest], 78 | classOf[SourceWithMarkerTest], 79 | classOf[OrganizeImportsAlgosTest], 80 | classOf[TracingHelpersTest], 81 | classOf[OrganizeImportsWithMacrosTest], 82 | classOf[OrganizeImportsScalaSpecificTests], 83 | classOf[OrganizeImportsEndOfLineTest]))
84 | class RefactoringTestSuite {} 85 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/analysis/FindShadowedTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package tests.analysis 7 | 8 | import tests.util.TestHelper 9 | import org.junit.Assert._ 10 | 11 | class FindShadowedTest extends TestHelper { 12 | 13 | import global._ 14 | 15 | @Ignore 16 | @Test 17 | def findSimpleShadowing(): Unit = { 18 | 19 | val t = treeFrom(""" 20 | package shadowing 21 | 22 | object TheShadow { 23 | val i = 1 24 | 25 | def method: Unit = { 26 | val i = "" 27 | () 28 | } 29 | } 30 | 31 | class Xyz(xyzxyz: Long) { 32 | 33 | def method: Unit = { 34 | val xyzxyz = "" 35 | val i = xyzxyz 36 | () 37 | } 38 | } 39 | 40 | class Z { 41 | import TheShadow._ 42 | 43 | val i = Nil 44 | }""") 45 | 46 | val results = new collection.mutable.ListBuffer[Symbol] 47 | 48 | t foreach { 49 | case v @ ValDef(_, name, _, _) => 50 | 51 | val members = { 52 | 53 | def contexts(n: Context): Stream[Context] = n #:: contexts(n.outer) 54 | 55 | val context = global.doLocateContext(v.pos).outer 56 | 57 | val fromEnclosingScopes = contexts(context).takeWhile(_ != NoContext) flatMap { 58 | ctx => ctx.scope ++ (if (ctx == ctx.enclClass) ctx.prefix.members else Nil) 59 | } 60 | 61 | val fromImported = (context.imports flatMap (_.allImportedSymbols)) 62 | 63 | fromEnclosingScopes ++ fromImported 64 | } 65 | 66 | results ++= members.find(s => s.name == name && s.pos != v.symbol.pos) 67 | 68 | case _ => Nil 69 | } 70 | 71 | assertEquals(3, results.size) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/analysis/ImportAnalysisTest.scala:
-------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.tests.analysis 2 | 3 | import scala.tools.refactoring.tests.util.TestHelper 4 | import scala.tools.refactoring.analysis.ImportAnalysis 5 | import org.junit.Assert._ 6 | /** Tests ImportAnalysis: the printed shape of the tree built by buildImportTree, and isImportedAt for explicit, wildcard and Predef imports (plus an ignored case for imports from value members). */ 7 | class ImportAnalysisTest extends TestHelper with ImportAnalysis { 8 | import global._ 9 | 10 | @Test 11 | def importTrees() = global.ask { () => 12 | val tree = treeFrom(""" 13 | import scala.collection.immutable.{LinearSeq, _} 14 | 15 | object O{ 16 | import scala.math._ 17 | val a = 1 18 | } 19 | 20 | object P{ 21 | import collection.mutable._ 22 | import O.a 23 | } 24 | """) 25 | 26 | val it = buildImportTree(tree) 27 | assertEquals("{scala.collection.immutable.List{scala.`package`._{scala.Predef._{scala.collection.immutable.LinearSeq{scala.collection.immutable._{scala.math._{}, scala.collection.mutable._{O.a{}}}}}}}}", it.toString()) 28 | } 29 | 30 | @Test 31 | def isImported() = global.ask { () => 32 | val tree = treeFrom(""" 33 | import scala.math.E 34 | 35 | object O{ 36 | def fn = { 37 | import scala.math.Pi 38 | 99 * Pi * E 39 | } 40 | } 41 | """) 42 | 43 | val it = buildImportTree(tree) 44 | val piRef = findSymTree(tree, "value Pi") 45 | val eRef = findSymTree(tree, "value E") 46 | val fnDef = findSymTree(tree, "method fn") 47 | 48 | assertTrue(it.isImportedAt(piRef.symbol, piRef.pos)) 49 | assertTrue(it.isImportedAt(eRef.symbol, eRef.pos)) 50 | 51 | assertFalse(it.isImportedAt(piRef.symbol, fnDef.pos)) 52 | assertTrue(it.isImportedAt(eRef.symbol, fnDef.pos)) 53 | } 54 | 55 | @Test 56 | def isImportedWithWildcard() = global.ask { () => 57 | val tree = treeFrom(""" 58 | object O{ 59 | def fn = { 60 | import scala.math._ 61 | 99 * Pi * E 62 | } 63 | } 64 | """) 65 | 66 | val it = buildImportTree(tree) 67 | val piRef = findSymTree(tree, "value Pi") 68 | val eRef = findSymTree(tree, "value E") 69 | val fnDef = findSymTree(tree, "method fn") 70 | 71 |
assertTrue(it.isImportedAt(piRef.symbol, piRef.pos)) 72 | assertTrue(it.isImportedAt(eRef.symbol, eRef.pos)) 73 | 74 | assertFalse(it.isImportedAt(piRef.symbol, fnDef.pos)) 75 | } 76 | 77 | @Test 78 | def predefsAreAlwaysImported() = global.ask { () => 79 | val tree = treeFrom(""" 80 | object O{ 81 | println(123) 82 | List(1, 2, 3) 83 | } 84 | """) 85 | 86 | val it = buildImportTree(tree) 87 | val printlnRef = findSymTree(tree, "method println") 88 | val listRef = findSymTree(tree, "object List") 89 | 90 | val oDef = findSymTree(tree, "object O") 91 | 92 | assertTrue(it.isImportedAt(printlnRef.symbol, oDef.pos)) 93 | assertTrue(it.isImportedAt(listRef.symbol, oDef.pos)) 94 | } 95 | 96 | @Test 97 | @Ignore 98 | def importsOfValueMembers() = global.ask { () => 99 | val tree = treeFrom(""" 100 | package pkg 101 | object O{ 102 | val a = new{ val b = 1 } 103 | 104 | import a._ 105 | 106 | def fn = b 107 | } 108 | """) 109 | 110 | val it = buildImportTree(tree) 111 | val bRef = findSymTree(findSymTree(tree, "method fn"), "value b") 112 | 113 | val oDef = findSymTree(tree, "object O") 114 | val aDef = findSymTree(tree, "value a") 115 | 116 | assertFalse(it.isImportedAt(bRef.symbol, oDef.pos)) 117 | assertFalse(it.isImportedAt(bRef.symbol, aDef.pos)) 118 | 119 | assertTrue(it.isImportedAt(bRef.symbol, bRef.pos)) 120 | } 121 | /** Returns the first SymTree in `t` whose symbol prints as `s`; fails with a descriptive AssertionError instead of the bare NoSuchElementException that `.head` would throw when nothing matches. */ 122 | def findSymTree(t: Tree, s: String): SymTree = 123 | t.collect { 124 | case st: SymTree if st.symbol.toString == s => st 125 | }.headOption.getOrElse(throw new AssertionError(s"no SymTree with symbol '$s' found")) 126 | } 127 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/analysis/TreeAnalysisTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package tests.analysis 7 | 8 | import tests.util.TestHelper 9 | import org.junit.Assert._ 10 | import analysis.GlobalIndexes 11 | import analysis.TreeAnalysis 12 | /** Tests TreeAnalysis: the inbound and outbound local dependencies of a marked selection, computed against a GlobalIndex built from the parsed source. */ 13 | class
TreeAnalysisTest extends TestHelper with GlobalIndexes with TreeAnalysis { 14 | 15 | import global._ 16 | 17 | var index: IndexLookup = null 18 | /* Parses `src`, builds the GlobalIndex used by the analyses (inside global.ask), then runs `body` on the parsed tree. */ 19 | def withIndex(src: String)(body: Tree => Unit ): Unit = { 20 | val tree = treeFrom(src) 21 | global.ask { () => 22 | index = GlobalIndex(List(CompilationUnitIndex(tree))) 23 | } 24 | body(tree) 25 | } 26 | /* The owner of the first selected symbol is passed as the owner argument of inboundLocalDependencies. */ 27 | def assertInboundLocalDependencies(expected: String, src: String) = withIndex(src) { tree => 28 | 29 | val selection = findMarkedNodes(src, tree) 30 | val in = global.ask(() => inboundLocalDependencies(selection, selection.selectedSymbols.head.owner)) 31 | assertEquals(expected, in mkString ", ") 32 | } 33 | 34 | def assertOutboundLocalDependencies(expected: String, src: String) = withIndex(src) { tree => 35 | 36 | val selection = findMarkedNodes(src, tree) 37 | val out = global.ask(() => outboundLocalDependencies(selection)) 38 | assertEquals(expected, out mkString ", ") 39 | } 40 | 41 | @Test 42 | def findInboundLocalAndParameter() = { 43 | 44 | assertInboundLocalDependencies("value i, value a", """ 45 | class A9 { 46 | def addThree(i: Int) = { 47 | val a = 1 48 | /*(*/ val b = a + 1 + i /*)*/ 49 | val c = b + 1 50 | c 51 | } 52 | } 53 | """) 54 | } 55 | 56 | @Test 57 | def findParameterDependency() = { 58 | 59 | assertInboundLocalDependencies("value i", """ 60 | class A8 { 61 | def addThree(i: Int) = { 62 | val a = 1 63 | /*(*/ val b = for(x <- 0 to i) yield x /*)*/ 64 | "done" 65 | } 66 | } 67 | """) 68 | } 69 | 70 | @Test 71 | def findNoDependency() = { 72 | 73 | assertInboundLocalDependencies("", """ 74 | class A7 { 75 | def addThree(i: Int) = { 76 | val a = 1 77 | /*(*/ val b = 2 * 21 /*)*/ 78 | b 79 | } 80 | } 81 | """) 82 | } 83 | 84 | @Test 85 | def findDependencyOnMethod() = { 86 | 87 | assertInboundLocalDependencies("value i, method inc", """ 88 | class A6 { 89 | def addThree(i: Int) = { 90 | def inc(j: Int) = j + 1 91 | /*(*/ val b = inc(inc(inc(i))) /*)*/ 92 | b 93 | } 94 | } 95 | """) 96 | } 97 |
98 | @Test 99 | def findOutboundDeclarations() = { 100 | 101 | assertOutboundLocalDependencies("value b", """ 102 | class A5 { 103 | def addThree = { 104 | /*(*/ val b = 1 /*)*/ 105 | b + b + b 106 | } 107 | } 108 | """) 109 | } 110 | 111 | @Test 112 | def multipleReturnValues() = { 113 | 114 | assertOutboundLocalDependencies("value a, value b, value c", """ 115 | class TreeAnalysisTest { 116 | def addThree = { 117 | /*(*/ val a = 'a' 118 | val b = 'b' 119 | val c = 'c'/*)*/ 120 | a + b + c 121 | } 122 | } 123 | """) 124 | } 125 | 126 | @Test 127 | def dontReturnArgument() = { 128 | 129 | assertOutboundLocalDependencies("", """ 130 | class TreeAnalysisTest { 131 | def go = { 132 | var a = 1 133 | /*(*/ a = 2 /*)*/ 134 | a 135 | } 136 | } 137 | """) 138 | } 139 | 140 | @Test 141 | def findOnClassLevel() = { 142 | 143 | assertInboundLocalDependencies("", """ 144 | class Outer { 145 | class B2 { 146 | val a = 1 147 | /*(*/ val b = a + 1 /*)*/ 148 | 149 | def addThree(i: Int) = { 150 | val a = 1 151 | val b = 2 * 21 152 | b 153 | } 154 | } 155 | } 156 | """) 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/common/EnrichedTreesTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package tests.common 7 | 8 | import tests.util.TestHelper 9 | import org.junit.Assert._ 10 | import common.EnrichedTrees 11 | /** Tests the EnrichedTrees navigation helpers (originalParentOf, originalLeftSibling, originalRightSibling) and namePosition on small parsed sources. */ 12 | class EnrichedTreesTest extends TestHelper with EnrichedTrees { 13 | 14 | import global._ 15 | 16 | def tree = treeFrom(""" 17 | package treetest 18 | 19 | class Test { 20 | val test = 42 21 | val test2 = 42 22 | } 23 | 24 | """) 25 | 26 | @Test 27 | def classHasNoRightSibling() = global.ask { () => 28 | 29 | val c = tree.find(_.isInstanceOf[ClassDef]).get 30 | 31 | assertFalse(originalRightSibling(c).isDefined) 32 |
assertTrue(originalLeftSibling(c).isDefined) 33 | } 34 | 35 | @Test 36 | def templateNoSiblings() = global.ask { () => 37 | 38 | val c = tree.find(_.isInstanceOf[Template]).get 39 | 40 | assertTrue(originalLeftSibling(c).isDefined) 41 | assertFalse(originalRightSibling(c).isDefined) 42 | } 43 | 44 | @Test 45 | def parentChain() = global.ask { () => 46 | 47 | val i = tree.find(_.toString == "42").get 48 | 49 | val root = originalParentOf(i) flatMap (originalParentOf(_) flatMap (originalParentOf(_) flatMap originalParentOf)) 50 | 51 | assertTrue(root.get.isInstanceOf[PackageDef]) 52 | } 53 | 54 | @Test 55 | def rootHasNoParent() = global.ask { () => 56 | assertEquals(None, originalParentOf(tree)) 57 | } 58 | 59 | @Test 60 | def testSiblings() = global.ask { () => 61 | 62 | val v = tree.find(_.isInstanceOf[ValDef]).get 63 | val actual = originalParentOf(v).get.toString.replaceAll("\r\n", "\n") 64 | 65 | assertTrue(actual.contains("private[this] val test: Int = 42;")) 66 | 67 | assertEquals(None, originalLeftSibling(v)) 68 | assertEquals("private[this] val test2: Int = 42", originalRightSibling(v).get.toString) 69 | } 70 | 71 | @Test 72 | def namePositionOfFieldAccessor() = { 73 | val src = """ 74 | object O{ 75 | val field = 1 76 | } 77 | """ 78 | val root = treeFrom(src) 79 | 80 | val accessor = root.collect{ 81 | case t: ValOrDefDef if t.name.decode == "field" => t 82 | }.head 83 | 84 | assertEquals(src.indexOf("field"), accessor.namePosition().point) 85 | } 86 | } 87 | 88 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/common/OccurrencesTest.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.tests.common 2 | 3 | import scala.tools.refactoring.tests.util.TestHelper 4 | import scala.tools.refactoring.common.Occurrences 5 | import org.junit.Assert._ 6 | import scala.tools.refactoring.analysis.GlobalIndexes 7 | 8 |
/** Tests Occurrences.termNameOccurrences and defDefParameterOccurrences against a GlobalIndex built from the parsed source; occurrences are checked as (start offset, name length) pairs. */ class OccurrencesTest extends TestHelper with GlobalIndexes with Occurrences { 9 | import global._ 10 | 11 | var index: IndexLookup = null 12 | 13 | def withIndex(src: String)(body: Tree => Unit): Unit = { 14 | val tree = treeFrom(src) 15 | global.ask { () => 16 | index = GlobalIndex(List(CompilationUnitIndex(tree))) 17 | } 18 | body(tree) 19 | } 20 | 21 | @Test 22 | def termOccurrences() = { 23 | val src = """ 24 | object O{ 25 | def fn = { 26 | val a = 1 27 | val b = { 28 | val a = 2 29 | a 30 | } 31 | a * a 32 | } 33 | } 34 | """ 35 | withIndex(src) { root => 36 | val os = global.ask { () => termNameOccurrences(root, "a") } 37 | assertEquals(3, os.length) 38 | assertEquals(src.indexOf("a = 1"), os.head._1) 39 | assertTrue(os.forall(_._2 == "a".length)) 40 | } 41 | } 42 | 43 | @Test 44 | def accessorNameOccurrences() = { 45 | val src = """ 46 | object O{ 47 | val field = 1 48 | 49 | def fn = 2 * field 50 | } 51 | """ 52 | withIndex(src) { root => 53 | val os = global.ask { () => termNameOccurrences(root, "field") } 54 | assertEquals(2, os.length) 55 | assertEquals(src.indexOf("field"), os.head._1) 56 | assertTrue(os.forall(_._2 == "field".length)) 57 | } 58 | } 59 | 60 | @Test 61 | def paramsOccurrences() = { 62 | val src = """ 63 | object O{ 64 | def fn(a: Int, b: Int) = { 65 | val c = a * b 66 | a * b * c 67 | } 68 | } 69 | """ 70 | withIndex(src) { root => 71 | val os = global.ask { () => defDefParameterOccurrences(root, "fn") } 72 | // two params (checked before indexing so a short result fails with a clear message) 73 | assertEquals(2, os.length) 74 | val aos = os(0) 75 | val bos = os(1) 76 | assertEquals(src.indexOf("a: Int"), aos.head._1) 77 | assertEquals(src.indexOf("b: Int"), bos.head._1) 78 | assertTrue((aos ::: bos).forall(_._2 == 1)) 79 | // each name occurs three times 80 | assertEquals(3, aos.length) 81 | assertEquals(3, bos.length) 82 | } 83 | } 84 | } -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/common/SelectionPropertiesTest.scala:
-------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.tests.common 2 | 3 | import scala.tools.refactoring.common.Selections 4 | import scala.tools.refactoring.tests.util.TestHelper 5 | 6 | import org.junit.Assert._ 7 | import scala.tools.refactoring.tests.util.TextSelections 8 | /** Tests the classification properties of a Selection: representsValue, representsValueDefinitions, representsArgument, representsParameter and mayHaveSideEffects. */ 9 | class SelectionPropertiesTest extends TestHelper with Selections { 10 | /* Lets a source string with a single marked range (extracted via TextSelections.extractOne) be parsed into `root` and turned into a FileSelection of that range via `.selection`. */ 11 | implicit class StringToSel(src: String) { 12 | val root = treeFrom(src) 13 | val selection = { 14 | val textSelection = TextSelections.extractOne(src) 15 | FileSelection(root.pos.source.file, root, textSelection.from, textSelection.to) 16 | } 17 | } 18 | 19 | @Test 20 | def representsValue() = global.ask { () => 21 | val sel = """ 22 | object O{ 23 | def fn = { 24 | /*(*/val i = 100 25 | i * 2/*)*/ 26 | } 27 | } 28 | """.selection 29 | assertTrue(sel.representsValue) 30 | } 31 | 32 | @Test 33 | def doesNotRepresentValue() = global.ask { () => 34 | val sel = """ 35 | object O{ 36 | def fn = { 37 | /*(*/val i = 100 38 | val b = i * 2/*)*/ 39 | } 40 | } 41 | """.selection 42 | assertFalse(sel.representsValue) 43 | } 44 | 45 | @Test 46 | def nonValuePatternsDoNotRepresentValues() = global.ask { () => 47 | val selWildcard = """object O { 1 match { case /*(*/_/*)*/ => () } }""".selection 48 | assertFalse(selWildcard.representsValue) 49 | 50 | val selCtorPattern = """object O { Some(1) match { case /*(*/Some(i)/*)*/ => () } }""".selection 51 | assertFalse(selCtorPattern.representsValue) 52 | 53 | val selBinding = """object O { 1 match { case /*(*/i: Int/*)*/ => i } }""".selection 54 | assertFalse(selBinding.representsValue) 55 | 56 | val selPatAndGuard = """object O { 1 match { case /*(*/i if i > 10/*)*/ => i } }""".selection 57 | assertFalse(selPatAndGuard.representsValue) 58 | } 59 | 60 | @Test 61 | def valuePatternsDoRepresentValues() = global.ask { () => 62 | val selCtorPattern = """object O { Some(1) match { case /*(*/Some(1)/*)*/ => () } }""".selection
63 | assertTrue(selCtorPattern.representsValue) 64 | } 65 | 66 | @Test 67 | def argumentLists() = global.ask { () => 68 | val sel = """ 69 | object O{ 70 | def fn = { 71 | List(/*(*/1, 2/*)*/, 3) 72 | } 73 | } 74 | """.selection 75 | assertFalse(sel.representsValue) 76 | assertFalse(sel.representsValueDefinitions) 77 | assertTrue(sel.representsArgument) 78 | } 79 | 80 | @Test 81 | def parameter() = global.ask { () => 82 | val sel = """ 83 | object O{ 84 | def fn(/*(*/a: Int/*)*/) = { 85 | a 86 | } 87 | } 88 | """.selection 89 | assertFalse(sel.representsValue) 90 | assertTrue(sel.representsValueDefinitions) 91 | assertTrue(sel.representsParameter) 92 | } 93 | 94 | @Test 95 | def multipleParameters() = global.ask { () => 96 | val sel = """ 97 | object O{ 98 | def fn(/*(*/a: Int, b: Int/*)*/) = { 99 | a * b 100 | } 101 | } 102 | """.selection 103 | assertFalse(sel.representsValue) 104 | assertTrue(sel.representsValueDefinitions) 105 | assertTrue(sel.representsParameter) 106 | } 107 | 108 | @Test 109 | def triggersSideEffects() = global.ask { () => 110 | val sel = """ 111 | object O{ 112 | var a = 1 113 | /*(*/def fn = { 114 | a += 1 115 | a 116 | }/*)*/ 117 | } 118 | """.selection 119 | assertTrue(sel.mayHaveSideEffects) 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/common/SelectionsTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package tests.common 7 | 8 | import tests.util.TestHelper 9 | import org.junit.Assert._ 10 | import scala.tools.refactoring.tests.util.TextSelections 11 | /** Tests which trees and symbols a FileSelection reports for a marked range of source (allSelectedTrees, selectedSymbols and selectedSymbolTree). */ 12 | class SelectionsTest extends TestHelper { 13 | 14 | private def getIndexedSelection(src: String) = { 15 | val tree = treeFrom(src) 16 | val textSelection = TextSelections.extractOne(src) 17 | FileSelection(tree.pos.source.file, tree,
textSelection.from, textSelection.to) 18 | } 19 | 20 | def selectedLocalVariable(expected: String, src: String) = { 21 | 22 | val selection = getIndexedSelection(src) 23 | 24 | assertEquals(expected, selection.selectedSymbolTree.get.symbol.name.toString) 25 | } 26 | 27 | def assertSelection(expectedTrees: String, expectedSymbols: String, src: String) = { 28 | 29 | val selection = getIndexedSelection(src) 30 | 31 | assertEquals(expectedTrees, selection.allSelectedTrees map (_.getClass.getSimpleName) mkString ", ") 32 | assertEquals(expectedSymbols, selection.selectedSymbols mkString ", ") 33 | } 34 | 35 | @Test 36 | def findValDefInMethod() = { 37 | assertSelection( 38 | "ValDef, Apply, Select, Ident, Ident", 39 | "value b, method +, value a, value i", """ 40 | package findValDefInMethod 41 | class A { 42 | def addThree(i: Int) = { 43 | val a = 1 44 | /*(*/ val b = a + i /*)*/ 45 | val c = b + 1 46 | c 47 | } 48 | } 49 | """) 50 | } 51 | 52 | @Test 53 | def findIdentInMethod() = { 54 | assertSelection("Ident", "value i", """ 55 | package findIdentInMethod 56 | class A { 57 | def addThree(i: Int) = { 58 | val a = 1 59 | val b = a + /*(*/ i /*)*/ 60 | val c = b + 1 61 | c 62 | } 63 | } 64 | """) 65 | } 66 | 67 | @Test 68 | def findInMethodArguments() = { 69 | assertSelection("ValDef, TypeTree", "value i", """ 70 | package findInMethodArguments 71 | class A { 72 | def addThree(/*(*/ i : Int /*)*/) = { 73 | i 74 | } 75 | } 76 | """) 77 | } 78 | 79 | @Test 80 | def findWholeMethod() = { 81 | assertSelection( 82 | "DefDef, ValDef, TypeTree, Apply, Select, Ident, Literal", 83 | "method addThree, value i, method *, value i", """ 84 | package findWholeMethod 85 | class A { 86 | /*(*/ 87 | def addThree(i: Int) = { 88 | i * 5 89 | } 90 | /*)*/ 91 | } 92 | """) 93 | 94 | } 95 | @Test 96 | def findNothing() = { 97 | assertSelection("", "", """ 98 | package findNothing 99 | class A { 100 | /*(*/ /*)*/ 101 | def addThree(i: Int) = { 102 | i * 5 103 | } 104 | } 105 | """) 106 | }
107 | 108 | @Test 109 | def findSelectedLocal() = { 110 | selectedLocalVariable("copy", """ 111 | package findSelectedLocal 112 | class A { 113 | def times5(i: Int) = { 114 | val /*(*/copy/*)*/ = i 115 | copy * 5 116 | } 117 | } 118 | """) 119 | } 120 | 121 | @Test 122 | def selectedTheFirstCompleteSymbol() = { 123 | selectedLocalVariable("i", """ 124 | package selectedTheFirstCompleteSymbol 125 | class A { 126 | def times5(i: Int) = { 127 | val /*(*/copy = i /*)*/ 128 | copy * 5 129 | } 130 | } 131 | """) 132 | } 133 | 134 | @Test 135 | def selectedTheFirstSymbol() = { 136 | selectedLocalVariable("copy", """ 137 | package selectedTheFirstSymbol 138 | class A { 139 | def times5(i: Int) = { 140 | /*(*/ val copy = i /*)*/ 141 | copy * 5 142 | } 143 | } 144 | """) 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/implementations/AddFieldTest.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | package tests.implementations 3 | 4 | import common.Change 5 | import org.junit.Assert.assertEquals 6 | import tests.util.TestHelper 7 | import scala.tools.refactoring.implementations._ 8 | import language.reflectiveCalls 9 | import scala.tools.refactoring.util.UniqueNames 10 | /** Tests the AddField refactoring: adding a val or var, with or without a type, to an object, class, inner class, case class or trait, or to the declaration closest to a given offset (AddToClosest). */ 11 | class AddFieldTest extends TestHelper { 12 | outer => 13 | /* Registers `src` with the compiler under a unique name, runs AddField with the given arguments and asserts that applying the resulting change to `src` yields `expected`. */ 14 | def addField(className: String, valName: String, isVar: Boolean, returnType: Option[String], target: AddMethodTarget, src: String, expected: String) = { 15 | global.ask { () => 16 | val refactoring = new AddField { 17 | val global = outer.global 18 | val file = addToCompiler(UniqueNames.basename(), src) 19 | val change = addField(file, className, valName, isVar, returnType, target) 20 | } 21 | assertEquals(expected, Change.applyChanges(refactoring.change, src)) 22 | } 23 | } 24 | 25 | @Test 26 | def addValToObject() = { 27 | addField("Main", "field", isVar = false,
Option("Any"), AddToObject, """ 28 | class Main 29 | object Main { 30 | 31 | }""", 32 | """ 33 | class Main 34 | object Main { 35 | val field: Any = ??? 36 | }""") 37 | } 38 | 39 | @Test 40 | def addValToObject2() = { 41 | addField("Main", "*", isVar = false, None, AddToObject, """ 42 | class Main 43 | object Main { 44 | 45 | }""", 46 | """ 47 | class Main 48 | object Main { 49 | val * = ??? 50 | }""") 51 | } 52 | 53 | @Test 54 | def addVarToObject() = { 55 | addField("Main", "field", isVar = true, Option("Any"), AddToObject, """ 56 | class Main 57 | object Main { 58 | 59 | }""", 60 | """ 61 | class Main 62 | object Main { 63 | var field: Any = ??? 64 | }""") 65 | } 66 | 67 | @Test 68 | def addVarToObject2() = { 69 | addField("Main", "*", isVar = true, None, AddToObject, """ 70 | class Main 71 | object Main { 72 | 73 | }""", 74 | """ 75 | class Main 76 | object Main { 77 | var * = ??? 78 | }""") 79 | } 80 | 81 | @Test 82 | def addValToClass() = { 83 | addField("Main", "field", isVar = false, Option("Any"), AddToClass, """ 84 | object Main 85 | class Main { 86 | def existingMethod = "this is an existing method" 87 | }""", 88 | """ 89 | object Main 90 | class Main { 91 | def existingMethod = "this is an existing method" 92 | 93 | val field: Any = ??? 94 | }""") 95 | } 96 | 97 | @Test 98 | def addValToInnerClass() = { 99 | addField("Inner", "field", isVar = false, Option("Any"), AddToClass, """ 100 | class Main { 101 | class Inner 102 | }""", 103 | """ 104 | class Main { 105 | class Inner { 106 | val field: Any = ??? 107 | } 108 | }""") 109 | } 110 | 111 | @Test 112 | def addValToCaseClass() = { 113 | addField("Main", "field", isVar = false, Option("Any"), AddToClass, """ 114 | case class Main""", 115 | """ 116 | case class Main { 117 | val field: Any = ???
118 | }""") 119 | } 120 | 121 | @Test 122 | def addValToTrait() = { 123 | addField("Main", "field", isVar = false, Option("Any"), AddToClass, """ 124 | trait Main""", 125 | """ 126 | trait Main { 127 | val field: Any = ??? 128 | }""") 129 | } 130 | 131 | @Test 132 | def addValByClosestPosition() = { 133 | addField("Main", "field", isVar = false, Option("Any"), AddToClosest(30), """ 134 | class Main { 135 | 136 | } 137 | object Main { 138 | 139 | }""", 140 | """ 141 | class Main { 142 | 143 | } 144 | object Main { 145 | val field: Any = ??? 146 | }""") 147 | } 148 | 149 | @Test 150 | def addVarByClosestPosition() = { 151 | addField("Main", "field", isVar = true, Option("Any"), AddToClosest(30), """ 152 | class Main { 153 | 154 | } 155 | object Main { 156 | 157 | }""", 158 | """ 159 | class Main { 160 | 161 | } 162 | object Main { 163 | var field: Any = ??? 164 | }""") 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/implementations/ExplicitGettersSettersTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package tests.implementations 7 | 8 | import implementations.ExplicitGettersSetters 9 | import tests.util.TestHelper 10 | import tests.util.TestRefactoring 11 | 12 | /** Tests the ExplicitGettersSetters refactoring on selected class parameters (val and var). Every test carries @ScalaVersion(doesNotMatch = "2.12"), presumably so that it is skipped when running on Scala 2.12. */ 13 | class ExplicitGettersSettersTest extends TestHelper with TestRefactoring { 14 | outer => 15 | 16 | def explicitGettersSetters(pro: FileSet) = new TestRefactoringImpl(pro) { 17 | val refactoring = new ExplicitGettersSetters { 18 | val global = outer.global 19 | } 20 | val changes = performRefactoring() 21 | }.changes 22 | 23 | @Test 24 | @ScalaVersion(doesNotMatch = "2.12") 25 | def oneVarFromMany() = new FileSet { 26 | """ 27 | package oneFromMany 28 | class Demo(val a: String, /*(*/var i: Int/*)*/ ) { 29 | def doNothing = () 30 | } 31 | """ becomes 32 | """ 33 |
package oneFromMany 34 | class Demo(val a: String, /*(*/private var _i: Int/*)*/ ) { 35 | def i = { 36 | _i 37 | } 38 | 39 | def i_=(i: Int) = { 40 | _i = i 41 | } 42 | def doNothing = () 43 | } 44 | """ 45 | } applyRefactoring(explicitGettersSetters) 46 | 47 | @Test 48 | @ScalaVersion(doesNotMatch = "2.12") 49 | def oneValFromMany() = new FileSet { 50 | """ 51 | package oneFromMany 52 | class Demo(val a: String, /*(*/val i: Int/*)*/ ) { 53 | def doNothing = () 54 | } 55 | """ becomes 56 | """ 57 | package oneFromMany 58 | class Demo(val a: String, /*(*/_i: Int/*)*/ ) { 59 | def i = { 60 | _i 61 | } 62 | 63 | def doNothing = () 64 | } 65 | """ 66 | } applyRefactoring(explicitGettersSetters) 67 | 68 | @Test 69 | @ScalaVersion(doesNotMatch = "2.12") 70 | def singleVal() = new FileSet { 71 | """ 72 | package oneFromMany 73 | class Demo( /*(*/val i: Int/*)*/ ) 74 | """ becomes 75 | """ 76 | package oneFromMany 77 | class Demo( /*(*/_i: Int/*)*/ ) { 78 | def i = { 79 | _i 80 | } 81 | } 82 | """ 83 | } applyRefactoring(explicitGettersSetters) 84 | 85 | @Test 86 | @ScalaVersion(doesNotMatch = "2.12") 87 | def singleValWithEmptyBody() = new FileSet { 88 | """ 89 | package oneFromMany 90 | class Demo( /*(*/val i: Int/*)*/ ) { 91 | 92 | } 93 | """ becomes 94 | """ 95 | package oneFromMany 96 | class Demo( /*(*/_i: Int/*)*/ ) { 97 | def i = { 98 | _i 99 | } 100 | } 101 | """ 102 | } applyRefactoring(explicitGettersSetters) 103 | } 104 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/implementations/extraction/ExtractionsTest.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.tests.implementations.extraction 2 | 3 | import scala.tools.refactoring.tests.util.TestHelper 4 | import scala.tools.refactoring.implementations.extraction.Extractions 5 | import org.junit.Assert._ 6 | /** Tests how many extraction targets an ExtractionCollector is offered for a selection; TestCollector accepts every extraction source and records the targets passed to createExtractions. */ 7 | class ExtractionsTest extends TestHelper with
Extractions { 8 | @Test 9 | def findExtractionTargets() = { 10 | val s = toSelection(""" 11 | object O{ 12 | def fn = { 13 | val a = 1 14 | /*(*/2 * a/*)*/ 15 | } 16 | } 17 | """) 18 | 19 | TestCollector.collect(s) 20 | assertEquals(2, TestCollector.extractionTargets.length) 21 | } 22 | 23 | @Test 24 | def noExtractionTargetsForSyntheticScopes() = { 25 | val s = toSelection(""" 26 | object O{ 27 | def fn = 28 | 1 :: 2 :: /*(*/Nil/*)*/ 29 | } 30 | """) 31 | 32 | TestCollector.collect(s) 33 | assertEquals(2, TestCollector.extractionTargets.length) 34 | } 35 | 36 | @Test 37 | def noExtractionTargetsForCasesWithSelectedPattern() = { 38 | val s = toSelection(""" 39 | object O{ 40 | 1 match { 41 | case /*(*/i/*)*/ => i 42 | } 43 | } 44 | """) 45 | 46 | TestCollector.collect(s) 47 | assertEquals(1, TestCollector.extractionTargets.length) 48 | } 49 | /* NOTE(review): TestCollector is one object shared by all tests and extractionTargets is only assigned inside createExtractions; if collect ever returned without calling it, an assertion would read the previous test's targets. */ 50 | object TestCollector extends ExtractionCollector[Extraction] { 51 | var extractionTargets: List[ExtractionTarget] = Nil 52 | 53 | def isValidExtractionSource(s: Selection) = true 54 | 55 | def createExtractions(source: Selection, targets: List[ExtractionTarget], name: String) = { 56 | extractionTargets = targets 57 | Nil 58 | } 59 | } 60 | } -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/implementations/imports/OrganizeImportsBaseTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package tests.implementations.imports 7 | 8 | import implementations.OrganizeImports 9 | import sourcegen.Formatting 10 | import tests.util.TestHelper 11 | import tests.util.TestRefactoring 12 | /** Base class for OrganizeImports tests. Subclasses create an OrganizeImportsRefatoring (name kept as-is because subclasses refer to it), which builds an OrganizeImports refactoring bound to this test's compiler, with dropScalaPackage and lineDelimiter taken from `formatting`, and supplies `params`; mkChanges runs the refactoring with those parameters. */ 13 | abstract class OrganizeImportsBaseTest extends TestHelper with TestRefactoring { 14 | 15 | abstract class OrganizeImportsRefatoring(pro: FileSet, formatting: Formatting = new Formatting{}) extends TestRefactoringImpl(pro) { 16 | val
refactoring = new OrganizeImports { 17 | val global = OrganizeImportsBaseTest.this.global 18 | override val dropScalaPackage = formatting.dropScalaPackage 19 | override val lineDelimiter = formatting.lineDelimiter 20 | } 21 | type RefactoringParameters = refactoring.RefactoringParameters 22 | val params: RefactoringParameters 23 | def mkChanges = performRefactoring(params) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/implementations/imports/OrganizeImportsCollapseSelectorsToWildcardTest.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.tests.implementations.imports 2 | 3 | import scala.tools.refactoring.implementations.OrganizeImports 4 | /** Tests the CollapseImports strategy configured with CollapseToWildcardConfig(maxIndividualImports = 2, exclude): which selector lists are collapsed to a wildcard and which are kept (renames, collisions with another import, excluded packages). */ 5 | class OrganizeImportsCollapseSelectorsToWildcardTest extends OrganizeImportsBaseTest { 6 | 7 | def organize(exclude: Set[String] = Set())(pro: FileSet) = new OrganizeImportsRefatoring(pro) { 8 | import refactoring._ 9 | val maxIndividualImports = 2 10 | val oiConfig = OrganizeImports.OrganizeImportsConfig( 11 | importsStrategy = Some(OrganizeImports.ImportsStrategy.CollapseImports), 12 | collapseToWildcardConfig = Some(OrganizeImports.CollapseToWildcardConfig(maxIndividualImports, exclude))) 13 | val params = new RefactoringParameters(deps = Dependencies.FullyRecompute, config = Some(oiConfig)) 14 | }.mkChanges 15 | 16 | @Test 17 | def collapseImportSelectorsToWildcard() = new FileSet { 18 | """ 19 | import scala.math.{BigDecimal, BigInt, Numeric} 20 | 21 | object A { 22 | (BigDecimal, BigInt, Numeric) 23 | }""" becomes 24 | """ 25 | import scala.math._ 26 | 27 | object A { 28 | (BigDecimal, BigInt, Numeric) 29 | }""" 30 | } applyRefactoring organize() 31 | 32 | @Test 33 | def dontCollapseImportsWhenRename() = new FileSet { 34 | """ 35 | package acme 36 | import scala.math.{BigDecimal, BigInt, Numeric => N} 37 | 38 | object A { 39 | (BigDecimal, BigInt,
N) 40 | }""" isNotModified 41 | } applyRefactoring organize() 42 | 43 | @Test 44 | def dontCollapseWhenCollidingWithExplicitImport() = new FileSet { 45 | """ 46 | import scala.collection.immutable.{HashSet, BitSet, HashMap} 47 | import scala.collection.mutable.{ArrayStack, ArrayBuilder, ArrayBuffer} 48 | 49 | object MyObject { 50 | (BitSet, HashMap, HashSet) 51 | (ArrayBuffer, ArrayBuilder, ArrayStack) 52 | }""" becomes 53 | """ 54 | import scala.collection.immutable._ 55 | import scala.collection.mutable.{ArrayBuffer, ArrayBuilder, ArrayStack} 56 | 57 | object MyObject { 58 | (BitSet, HashMap, HashSet) 59 | (ArrayBuffer, ArrayBuilder, ArrayStack) 60 | }""" 61 | } applyRefactoring organize() 62 | 63 | @Test 64 | def dontCollapseWhenPackageInExcludes() = new FileSet { 65 | val before = """ 66 | import scala.collection.immutable.{BitSet, HashMap, HashSet} 67 | 68 | object MyObject { 69 | (BitSet, HashMap, HashSet) 70 | }""" 71 | 72 | before becomes before 73 | } applyRefactoring organize(Set("scala.collection.immutable")) 74 | 75 | } 76 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/implementations/imports/OrganizeImportsEndOfLineTest.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | package tests.implementations.imports 3 | 4 | import scala.tools.refactoring.implementations.OrganizeImports 5 | import scala.tools.refactoring.implementations.OrganizeImports.Dependencies 6 | import scala.tools.refactoring.sourcegen.Formatting 7 | /** Checks that OrganizeImports writes the line delimiter configured in Formatting (LF or CRLF) when it removes or reorders imports. */ 8 | class OrganizeImportsEndOfLineTest extends OrganizeImportsBaseTest { 9 | private def organizeCustomized( 10 | formatting: Formatting, 11 | groupPkgs: List[String] = List("java", "scala", "org", "com"), 12 | useWildcards: Set[String] = Set("scalaz", "scalaz.Scalaz"), 13 | dependencies: Dependencies.Value = Dependencies.FullyRecompute, 14 | organizeLocalImports: Boolean = true)(pro:
FileSet) = new OrganizeImportsRefatoring(pro, formatting) { 15 | val oiConfig = OrganizeImports.OrganizeImportsConfig( 16 | importsStrategy = Some(OrganizeImports.ImportsStrategy.ExpandImports), 17 | wildcards = useWildcards, 18 | groups = groupPkgs) 19 | val params = { 20 | new refactoring.RefactoringParameters( 21 | deps = dependencies, 22 | organizeLocalImports = organizeLocalImports, 23 | config = Some(oiConfig)) 24 | } 25 | }.mkChanges 26 | 27 | private def organizeWithUnixEOL(pro: FileSet) = organizeCustomized(formatting = new Formatting { override def lineDelimiter = "\n" })(pro) 28 | private def organizeWithWindowsEOL(pro: FileSet) = organizeCustomized(formatting = new Formatting { override def lineDelimiter = "\r\n" })(pro) 29 | 30 | @Test 31 | def shouldPreserveUnixLineSeparator_v1() = new FileSet { 32 | "package testunix\n\nimport scala.util.Try\nimport java.util.List" becomes 33 | "package testunix\n\n" 34 | } applyRefactoring organizeWithUnixEOL 35 | 36 | @Test 37 | def shouldPreserveUnixLineSeparator_v2() = new FileSet { 38 | "package testunix\n\nimport scala.util.Try\nimport java.util.List\n" becomes 39 | "package testunix\n\n" 40 | } applyRefactoring organizeWithUnixEOL 41 | 42 | @Test 43 | def shouldPreserveUnixLineSeparator_v3() = new FileSet { 44 | "package testunix\n\nimport scala.util.Try\nimport java.util.List\n\nclass A(val t: Try[Int], val l: List[Int])" becomes 45 | "package testunix\n\nimport java.util.List\n\nimport scala.util.Try\n\nclass A(val t: Try[Int], val l: List[Int])" 46 | } applyRefactoring organizeWithUnixEOL 47 | 48 | @Test 49 | def shouldPreserveUnixLineSeparator_v4() = new FileSet { 50 | "package testunix\n\nimport scala.util.Try\nimport scala.util.Either\n\nclass A(val t: Try[Int], val l: Either[Int, Int])" becomes 51 | "package testunix\n\nimport scala.util.Either\nimport scala.util.Try\n\nclass A(val t: Try[Int], val l: Either[Int, Int])" 52 | } applyRefactoring organizeWithUnixEOL 53 | 54 | @Test 55 | def
shouldPreserveWindowsLineSeparator_v1() = new FileSet { 56 | "package testwin\r\n\r\nimport scala.util.Try\r\nimport java.util.List" becomes 57 | "package testwin\r\n\r\n" 58 | } applyRefactoring organizeWithWindowsEOL 59 | 60 | @Test 61 | def shouldPreserveWindowsLineSeparator_v2() = new FileSet { 62 | "package testwin\r\n\r\nimport scala.util.Try\r\nimport java.util.List\r\n" becomes 63 | "package testwin\r\n\r\n" 64 | } applyRefactoring organizeWithWindowsEOL 65 | 66 | @Test 67 | def shouldPreserveWindowsLineSeparator_v3() = new FileSet { 68 | "package testwin\r\n\r\nimport scala.util.Try\r\nimport java.util.List\r\n\r\nclass A(val t: Try[Int], val l: List[Int])" becomes 69 | "package testwin\r\n\r\nimport java.util.List\r\n\r\nimport scala.util.Try\r\n\r\nclass A(val t: Try[Int], val l: List[Int])" 70 | } applyRefactoring organizeWithWindowsEOL 71 | 72 | @Test 73 | def shouldPreserveWindowsLineSeparator_v4() = new FileSet { 74 | "package testwin\r\n\r\nimport scala.util.Try\r\nimport scala.util.Either\r\n\r\nclass A(val t: Try[Int], val l: Either[Int, Int])" becomes 75 | "package testwin\r\n\r\nimport scala.util.Either\r\nimport scala.util.Try\r\n\r\nclass A(val t: Try[Int], val l: Either[Int, Int])" 76 | } applyRefactoring organizeWithWindowsEOL 77 | } 78 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/implementations/imports/OrganizeImportsWildcardsTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package tests.implementations.imports 7 | 8 | import scala.tools.refactoring.implementations.OrganizeImports 9 | 10 | class OrganizeImportsWildcardsTest extends OrganizeImportsBaseTest { 11 | 12 | def organize(groups: Set[String])(pro: FileSet) = new OrganizeImportsRefatoring(pro) { 13 | import refactoring._ 14 | val oiConfig =
OrganizeImports.OrganizeImportsConfig( 15 | importsStrategy = Some(OrganizeImports.ImportsStrategy.PreserveWildcards), 16 | wildcards = groups) 17 | val params = new RefactoringParameters(deps = Dependencies.FullyRecompute, config = Some(oiConfig)) 18 | }.mkChanges 19 | 20 | val source = """ 21 | import scala.collection.mutable.Set 22 | import org.xml.sax.Attributes 23 | import Set._ 24 | 25 | trait Temp { 26 | // we need some code that use the imports 27 | val z: (Attributes, Set[String]) 28 | println(apply("")) 29 | } 30 | """ 31 | 32 | @Test 33 | def noGrouping() = new FileSet { 34 | source becomes 35 | """ 36 | import org.xml.sax.Attributes 37 | import scala.collection.mutable.Set 38 | import scala.collection.mutable.Set.apply 39 | 40 | trait Temp { 41 | // we need some code that use the imports 42 | val z: (Attributes, Set[String]) 43 | println(apply("")) 44 | } 45 | """ 46 | } applyRefactoring organize(Set()) 47 | 48 | @Test 49 | def simpleWildcard() = new FileSet { 50 | source becomes 51 | """ 52 | import org.xml.sax.Attributes 53 | import scala.collection.mutable.Set 54 | import scala.collection.mutable.Set._ 55 | 56 | trait Temp { 57 | // we need some code that use the imports 58 | val z: (Attributes, Set[String]) 59 | println(apply("")) 60 | } 61 | """ 62 | } applyRefactoring organize(Set("scala.collection.mutable.Set")) 63 | 64 | @Test 65 | def renamedImport() = new FileSet { 66 | """ 67 | import java.lang.Integer.{valueOf => vo} 68 | import java.lang.Integer.toBinaryString 69 | import java.lang.String.valueOf 70 | 71 | trait Temp { 72 | valueOf(5) 73 | vo("5") 74 | toBinaryString(27) 75 | } 76 | """ becomes 77 | """ 78 | import java.lang.Integer._ 79 | import java.lang.Integer.{valueOf => vo} 80 | import java.lang.String.valueOf 81 | 82 | trait Temp { 83 | valueOf(5) 84 | vo("5") 85 | toBinaryString(27) 86 | } 87 | """ 88 | } applyRefactoring organize(Set("java.lang.Integer")) 89 | 90 | @Test 91 | def multipleImportsOneWildcard() = new FileSet { 92 |
""" 93 | import java.lang.Integer.valueOf 94 | import java.lang.Integer.toBinaryString 95 | import java.lang.Double.toHexString 96 | 97 | trait Temp { 98 | valueOf("5") 99 | toBinaryString(27) 100 | toHexString(5) 101 | } 102 | """ becomes 103 | """ 104 | import java.lang.Double.toHexString 105 | import java.lang.Integer._ 106 | 107 | trait Temp { 108 | valueOf("5") 109 | toBinaryString(27) 110 | toHexString(5) 111 | } 112 | """ 113 | } applyRefactoring organize(Set("java.lang.Integer")) 114 | 115 | } 116 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/implementations/imports/UnusedImportsFinderTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package tests.implementations.imports 7 | 8 | import implementations.UnusedImportsFinder 9 | import tests.util.TestHelper 10 | import scala.tools.refactoring.util.UniqueNames 11 | 12 | class UnusedImportsFinderTest extends TestHelper { 13 | outer => 14 | 15 | def findUnusedImports(expected: String, src: String): Unit = { 16 | 17 | val unuseds = global.ask { () => 18 | new UnusedImportsFinder { 19 | 20 | val global = outer.global 21 | 22 | val unit = global.unitOfFile(addToCompiler(UniqueNames.basename(), src)) 23 | 24 | def compilationUnitOfFile(f: AbstractFile) = Some(unit) 25 | 26 | val unuseds = findUnusedImports(unit) 27 | }.unuseds 28 | } 29 | 30 | org.junit.Assert.assertEquals(expected, unuseds.mkString(", ")) 31 | } 32 | 33 | @Test 34 | def simpleUnusedType() = findUnusedImports( 35 | "(ListBuffer,2)", 36 | """ 37 | import scala.collection.mutable.ListBuffer 38 | 39 | object Main {val s: String = "" } 40 | """ 41 | ) 42 | 43 | @Test 44 | def typeIsUsedAsVal() = findUnusedImports( 45 | "", 46 | """ 47 | import scala.collection.mutable.ListBuffer 48 | 49 | object Main {val s = new ListBuffer[Int] } 50 | """ 
51 | ) 52 | 53 | @Test 54 | def typeIsImportedFrom() = findUnusedImports( 55 | "", 56 | """ 57 | class Forest { 58 | class Tree 59 | } 60 | 61 | class UsesTrees { 62 | val forest = new Forest 63 | import forest._ 64 | val x = new Tree 65 | } 66 | """ 67 | ) 68 | 69 | @Test 70 | def wildcardImports() = findUnusedImports( 71 | "", 72 | """ 73 | import scala.util.control.Exception._ 74 | 75 | class UsesTrees { 76 | val plugin = ScalaPlugin.plugin 77 | import plugin._ 78 | () 79 | } 80 | """ 81 | ) 82 | 83 | @Test 84 | def wildcardImportsFromValsAreIgnored() = findUnusedImports( 85 | "", 86 | """ 87 | object ScalaPlugin { 88 | var plugin: String = _ 89 | } 90 | 91 | class UsesTrees { 92 | val plugin = ScalaPlugin.plugin 93 | import plugin._ 94 | () 95 | } 96 | """ 97 | ) 98 | 99 | @Test 100 | def importFromJavaClass() = findUnusedImports( 101 | "", 102 | """ 103 | import java.util.Date 104 | 105 | object ScalaPlugin { 106 | import Date._ 107 | val max = parse(null) 108 | } 109 | """ 110 | ) 111 | 112 | @Test 113 | def ignoreImportIsNeverunused() = findUnusedImports( 114 | "", 115 | """ 116 | import java.util.{Date => _, _} 117 | 118 | object NoDate { 119 | var x: Stack[Int] = null 120 | } 121 | """ 122 | ) 123 | 124 | // more tests are in organize imports 125 | } 126 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/sourcegen/CustomFormattingTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package tests.sourcegen 7 | 8 | import tests.util.TestHelper 9 | import org.junit.Assert._ 10 | import sourcegen.SourceGenerator 11 | import scala.tools.refactoring.implementations.OrganizeImports 12 | import scala.tools.refactoring.tests.util.TestRefactoring 13 | 14 | 15 | class CustomFormattingTest extends TestHelper with TestRefactoring with SourceGenerator { 16 | 17 
| @volatile 18 | private var surroundingImport = "" 19 | 20 | override def spacingAroundMultipleImports = surroundingImport 21 | 22 | abstract class OrganizeImportsRefatoring(pro: FileSet) extends TestRefactoringImpl(pro) { 23 | val refactoring = new OrganizeImports { 24 | val global = CustomFormattingTest.this.global 25 | override def spacingAroundMultipleImports = surroundingImport 26 | } 27 | type RefactoringParameters = refactoring.RefactoringParameters 28 | val params: RefactoringParameters 29 | def mkChanges = performRefactoring(params) 30 | } 31 | 32 | def organize(pro: FileSet) = new OrganizeImportsRefatoring(pro) { 33 | val config = OrganizeImports.OrganizeImportsConfig( 34 | importsStrategy = Some(OrganizeImports.ImportsStrategy.CollapseImports) 35 | ) 36 | val params = new RefactoringParameters(config = Some(config)) 37 | }.mkChanges 38 | 39 | 40 | @Test 41 | def testSingleSpace(): Unit = { 42 | 43 | val ast = treeFrom(""" 44 | package test 45 | import scala.collection.{MapLike, MapProxy} 46 | """) 47 | 48 | surroundingImport = " " 49 | 50 | assertEquals(""" 51 | package test 52 | import scala.collection.{ MapLike, MapProxy } 53 | """, createText(ast, Some(ast.pos.source))) 54 | } 55 | 56 | @Test 57 | def collapse() = { 58 | surroundingImport = " " 59 | 60 | new FileSet { 61 | """ 62 | import java.lang.String 63 | import java.lang.Object 64 | 65 | object Main {val s: String = ""; var o: Object = null} 66 | """ becomes 67 | """ 68 | import java.lang.{ Object, String } 69 | 70 | object Main {val s: String = ""; var o: Object = null} 71 | """ 72 | } applyRefactoring organize 73 | } 74 | } 75 | 76 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/sourcegen/LayoutTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package tests.sourcegen 7 | 8 | 
import org.junit.Test 9 | import org.junit.Assert._ 10 | import sourcegen._ 11 | 12 | import language.postfixOps 13 | 14 | class LayoutTest { 15 | 16 | @Test 17 | def simpleConcatenation(): Unit = { 18 | assertEquals("ab", Fragment("a") ++ Fragment("b") asText) 19 | assertEquals("abc", Fragment("a") ++ Fragment("b") ++ Fragment("c") asText) 20 | assertEquals("ab", Layout("a") ++ Layout("b") asText) 21 | assertEquals("abc", Layout("a") ++ Fragment("b") ++ Layout("c") asText) 22 | assertEquals("abc", Fragment("a") ++ Layout("b") ++ Fragment("c") asText) 23 | } 24 | 25 | @Test 26 | def concatenationsWithEmpty(): Unit = { 27 | val N = NoLayout 28 | val F = EmptyFragment 29 | 30 | assertEquals("", N asText) 31 | assertEquals("", F asText) 32 | assertEquals("", N ++ F asText) 33 | assertEquals("", N ++ N asText) 34 | assertEquals("", F ++ N asText) 35 | assertEquals("", F ++ F asText) 36 | 37 | assertEquals("a", Fragment("a") ++ F asText) 38 | assertEquals("a", Fragment("a") ++ N asText) 39 | 40 | assertEquals("a", F ++ Fragment("a") asText) 41 | assertEquals("a", N ++ Fragment("a") asText) 42 | 43 | assertEquals("ab", Fragment("a") ++ F ++ Layout("b") asText) 44 | assertEquals("ab", Layout("a") ++ N ++ Fragment("b")asText) 45 | } 46 | 47 | @Test 48 | def complexConcatenations(): Unit = { 49 | val a = Layout("a") 50 | val b = Layout("b") 51 | val c = Layout("c") 52 | 53 | (Fragment(a, b, c) ++ Fragment(a, b, c)) match { 54 | case Fragment(a, b, c) => 55 | assertEquals("a", a.asText) 56 | assertEquals("bcab", b.asText) 57 | assertEquals("c", c.asText) 58 | } 59 | 60 | (Fragment(a, b, c) ++ a) match { 61 | case Fragment(a, b, c) => 62 | assertEquals("a", a.asText) 63 | assertEquals("b", b.asText) 64 | assertEquals("ca", c.asText) 65 | } 66 | 67 | (b ++ Fragment(a, b, c)) match { 68 | case Fragment(a, b, c) => 69 | assertEquals("ba", a.asText) 70 | assertEquals("b", b.asText) 71 | assertEquals("c", c.asText) 72 | } 73 | 74 | (Fragment("a") ++ Fragment("b")) match { 75 | 
case Fragment(a, b, c) => 76 | assertEquals("", a.asText) 77 | assertEquals("ab", b.asText) 78 | assertEquals("", c.asText) 79 | } 80 | } 81 | 82 | @Test 83 | def preserveRequisites(): Unit = { 84 | val r = Requisite.allowSurroundingWhitespace(",") 85 | val a = Fragment("a") 86 | val b = Fragment("b") 87 | val x = Layout("x") 88 | 89 | assertEquals("a,x", a ++ r ++ x asText) 90 | assertEquals("a,b", a ++ r ++ b asText) 91 | } 92 | 93 | @Test 94 | def requisitesAreBetween(): Unit = { 95 | val r = Requisite.allowSurroundingWhitespace(",") 96 | val a = Fragment(Layout("a"), Layout("b"), Layout("c")) 97 | val b = Fragment(Layout("x"), Layout("y"), Layout("z")) 98 | 99 | assertEquals("abc,xyz", a ++ r ++ b asText) 100 | } 101 | 102 | @Test 103 | def requisitesAreOnlyUsesWhenNeeded1(): Unit = { 104 | val r = Requisite.allowSurroundingWhitespace(",") 105 | val a = Fragment(Layout("a"), Layout("b"), Layout(",")) 106 | val b = Fragment(Layout("x"), Layout("y"), Layout("z")) 107 | 108 | assertEquals("ab,xyz", a ++ r ++ b asText) 109 | } 110 | 111 | @Test 112 | def requisitesAreOnlyUsesWhenNeeded2(): Unit = { 113 | val r = Requisite.allowSurroundingWhitespace(",") 114 | val a = Fragment(Layout("a"), Layout("b"), Layout("c")) 115 | val b = Fragment(Layout(","), Layout("y"), Layout("z")) 116 | 117 | assertEquals("abc,yzabc", a ++ r ++ b ++ a asText) 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/sourcegen/SourceHelperTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package tests.sourcegen 7 | 8 | import tests.util.TestHelper 9 | import org.junit.Assert._ 10 | import sourcegen.SourceUtils 11 | 12 | class SourceHelperTest extends TestHelper { 13 | 14 | import SourceUtils._ 15 | 16 | @Test 17 | def liftSingleLineComment(): Unit = { 18 | 19 | 
assertEquals(("abc ", " //x"), splitComment("abc//x")) 20 | 21 | assertEquals(("x x", " /**/ "), splitComment("x/**/x")) 22 | 23 | assertEquals(("5 *5", " /**/ "), splitComment("5/**/*5")) 24 | 25 | assertEquals(("5 *5", " /*/**/*/ "), splitComment("5/*/**/*/*5")) 26 | 27 | assertEquals(("4 /2", " /*/**/*/ "), splitComment("4/*/**/*//2")) 28 | } 29 | 30 | @Test 31 | def multiplication() = { 32 | assertEquals(""" 33 | object A { 34 | val r = 3 35 | val p = r * r 36 | }""", stripComment(""" 37 | object A { 38 | val r = 3 39 | val p = r/**/* r 40 | }""")) 41 | } 42 | 43 | @Test 44 | def stripCommentInClass() = { 45 | assertEquals(stripWhitespacePreservers(""" 46 | class A { 47 | def extractFrom(): Int = { 48 | val a = 1 49 | a + 1 ▒ 50 | } 51 | }"""), stripComment(""" 52 | class A { 53 | def extractFrom(): Int = { 54 | val a = 1 55 | /*(*/ a + 1/*)*/ 56 | } 57 | }""")) 58 | } 59 | } 60 | 61 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/transformation/TransformableSelectionTest.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.tests.transformation 2 | 3 | import scala.tools.refactoring.tests.util.TestHelper 4 | import org.junit.Assert._ 5 | import scala.tools.refactoring.transformation.TransformableSelections 6 | import scala.tools.refactoring.tests.util.TextSelections 7 | 8 | class TransformableSelectionTest extends TestHelper with TransformableSelections { 9 | import global._ 10 | 11 | val t123 = Literal(Constant(123)) 12 | val tprint123 = 13 | "object O { /*(*/println(123)/*)*/ }".selection.selectedTopLevelTrees.head 14 | 15 | implicit class StringToSel(src: String) { 16 | lazy val root = treeFrom(src) 17 | lazy val selection = { 18 | val textSelection = TextSelections.extractOne(src) 19 | FileSelection(root.pos.source.file, root, textSelection.from, textSelection.to) 20 | } 21 | 22 | def 
assertReplacement(mkTrans: Selection => Transformation[Tree, Tree]) = { 23 | val trans = mkTrans(selection) 24 | val result = trans(root) 25 | 26 | new { 27 | def toFail() = 28 | assertTrue(result.isEmpty) 29 | 30 | def toBecome(expectedSrc: String) = { 31 | val (expected, actual) = global.ask { () => 32 | (expectedSrc.root.toString(), result.get.toString()) 33 | } 34 | assertEquals(expected, actual) 35 | } 36 | 37 | def toBecomeTreeWith(assertion: Tree => Unit) = { 38 | assertion(result.get) 39 | } 40 | } 41 | } 42 | } 43 | 44 | @Test 45 | def replaceSingleStatement() = global.ask { () => """ 46 | object O{ 47 | def f = /*(*/1/*)*/ 48 | } 49 | """.assertReplacement(_.replaceBy(t123)).toBecome(""" 50 | object O{ 51 | def f = 123 52 | } 53 | """) 54 | } 55 | 56 | @Test 57 | def replaceSingleStatementInArgument() = global.ask { () => """ 58 | object O{ 59 | println(/*(*/1/*)*/) 60 | } 61 | """.assertReplacement(_.replaceBy(t123)).toBecome(""" 62 | object O{ 63 | println(123) 64 | } 65 | """) 66 | } 67 | 68 | @Test 69 | def replaceSequence() = global.ask { () => """ 70 | object O{ 71 | def f = { 72 | /*(*/println(1) 73 | println(2)/*)*/ 74 | println(3) 75 | } 76 | } 77 | """.assertReplacement(_.replaceBy(t123)).toBecome(""" 78 | object O{ 79 | def f = { 80 | 123 81 | println(3) 82 | } 83 | } 84 | """) 85 | } 86 | 87 | @Test 88 | def replaceAllExpressionsInBlock() = global.ask { () => """ 89 | object O{ 90 | def f = { 91 | /*(*/println(1) 92 | println(2) 93 | println(3)/*)*/ 94 | } 95 | } 96 | """.assertReplacement(_.replaceBy(tprint123)).toBecome(""" 97 | object O{ 98 | def f = println(123) 99 | } 100 | """) 101 | } 102 | 103 | @Test 104 | def replaceAllExpressionsInBlockPreservingHierarchy() = global.ask { () => """ 105 | object O{ 106 | def f = { 107 | /*(*/println(1) 108 | println(2) 109 | println(3)/*)*/ 110 | } 111 | } 112 | """.assertReplacement(_.replaceBy(tprint123, preserveHierarchy = true)).toBecomeTreeWith { t => 113 | val preservedBlock = t.find { 114 | // 
the new block must have an empty tree as its last expression 115 | case Block(stats, EmptyTree) => true 116 | case _ => false 117 | } 118 | assertTrue(preservedBlock.isDefined) 119 | } 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/util/FreshCompilerForeachTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring.tests.util 6 | 7 | import org.junit.After 8 | import scala.tools.refactoring.util.CompilerInstance 9 | 10 | trait FreshCompilerForeachTest extends TestHelper { 11 | 12 | // We are experiencing instable test runs, maybe it helps when we 13 | // use a fresh compiler for each test case: 14 | override val global = (new CompilerInstance).compiler 15 | 16 | @After 17 | def shutdownCompiler(): Unit = { 18 | global.askShutdown 19 | } 20 | } 21 | 22 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/util/SourceHelpersTest.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.tests.util 2 | 3 | import org.junit.Test 4 | import scala.tools.refactoring.util.SourceHelpers 5 | import org.junit.Assert._ 6 | import scala.tools.refactoring.util.SourceWithSelection 7 | 8 | class SourceHelpersTest { 9 | @Test 10 | def isRangeWithinWithTrivialArgs(): Unit = { 11 | testIsRangeWithin("", SourceWithSelection("x", 0, 1), false) 12 | testIsRangeWithin("x", SourceWithSelection("x", 0, 1), true) 13 | testIsRangeWithin("xx", SourceWithSelection("x", 0, 1), false) 14 | } 15 | 16 | @Test 17 | def isRangeWithEmptySelections(): Unit = { 18 | testIsRangeWithin("", SourceWithSelection("", 0, 0), true) 19 | testIsRangeWithin("a", SourceWithSelection("a", 0, 0), true) 20 | testIsRangeWithin("b", SourceWithSelection("a", 0, 
0), false) 21 | } 22 | 23 | @Test 24 | def isRangeWithinWithSimpleExamples(): Unit = { 25 | testIsRangeWithin("abab", SourceWithSelection("0abab5", 4, 5), true) 26 | testIsRangeWithin("abab", SourceWithSelection("0abab5", 1, 5), true) 27 | testIsRangeWithin("abab", SourceWithSelection("0abab5", 1, 6), false) 28 | testIsRangeWithin("abab", SourceWithSelection("0abab5", 2, 4), true) 29 | testIsRangeWithin("ab", SourceWithSelection("012a", 3, 4), false) 30 | testIsRangeWithin("abba", SourceWithSelection("012abba", 3, 4), true) 31 | } 32 | 33 | private def testIsRangeWithin(text: String, selection: SourceWithSelection, expected: Boolean): Unit = { 34 | val result = SourceHelpers.isRangeWithin(text, selection) 35 | assertEquals(expected, result) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/util/TestRefactoring.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2005-2010 LAMP/EPFL 3 | */ 4 | 5 | package scala.tools.refactoring 6 | package tests.util 7 | 8 | import common.Change 9 | import common.InteractiveScalaCompiler 10 | import scala.tools.refactoring.common.InteractiveScalaCompiler 11 | 12 | trait TestRefactoring extends TestHelper { 13 | 14 | class PreparationException(cause: String) extends Exception(cause) 15 | class RefactoringException(cause: String) extends Exception(cause) 16 | 17 | abstract class TestRefactoringImpl(project: FileSet) { 18 | 19 | trait TestProjectIndex extends GlobalIndexes { 20 | this: Refactoring => 21 | 22 | val global = TestRefactoring.this.global 23 | 24 | lazy val trees = { 25 | project.javaSources.foreach { case (code, filename) => addToCompiler(filename, code) } 26 | project.sources.map { case(code, filename) => addToCompiler(filename, code) }.map(global.unitOfFile(_).body) 27 | } 28 | 29 | override val index = global.ask { () => 30 | val cuIndexes = trees map 
(_.pos.source.file) map { file => 31 | global.unitOfFile(file).body 32 | } map CompilationUnitIndex.apply 33 | GlobalIndex(cuIndexes) 34 | } 35 | } 36 | 37 | val refactoring: MultiStageRefactoring with InteractiveScalaCompiler 38 | 39 | def preparationResult() = global.ask { () => 40 | refactoring.prepare(selection(refactoring, project)) 41 | } 42 | 43 | def preparationResult(selection: => refactoring.Selection) = global.ask { () => 44 | refactoring.prepare(selection) 45 | } 46 | 47 | def performRefactoring(): List[Change] = { 48 | performRefactoring(selection(refactoring, project), null.asInstanceOf[refactoring.RefactoringParameters]) 49 | } 50 | 51 | def performRefactoring(parameters: refactoring.RefactoringParameters): List[Change] = { 52 | performRefactoring(selection(refactoring, project), parameters) 53 | } 54 | 55 | def performRefactoring( 56 | selection: => refactoring.Selection, 57 | parameters: refactoring.RefactoringParameters): List[Change] = global.ask { () => 58 | preparationResult(selection) match { 59 | case Right(prepare) => 60 | refactoring.perform(selection, prepare, parameters) match { 61 | case Right(modifications) => modifications 62 | case Left(error) => throw new RefactoringException(error.cause) 63 | } 64 | case Left(error) => throw new PreparationException(error.cause) 65 | } 66 | } 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/util/TextSelectionsTest.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring.tests.util 2 | 3 | import org.junit.Test 4 | import org.junit.Assert.assertEquals 5 | import TextSelections.Range 6 | 7 | 8 | class TextSelectionsTest { 9 | @Test(expected = classOf[IllegalArgumentException]) 10 | def testWithEmptyString(): Unit = { 11 | TextSelections.extractOne("") 12 | } 13 | 14 | @Test(expected = classOf[IllegalArgumentException]) 15 | def 
testWithMoreThanOneSelection(): Unit = { 16 | TextSelections.extractOne("/*(*/ /*)*/ /*<-*/") 17 | } 18 | 19 | @Test 20 | def testWithOneEmptySelection(): Unit = { 21 | assertEquals(Range(0, 0), TextSelections.extractOne("/*<-*/")) 22 | assertEquals(Range(5, 5), TextSelections.extractOne("/*(*//*)*/")) 23 | } 24 | 25 | @Test 26 | def testWithOneNormalSelection(): Unit = { 27 | assertEquals(Range(5, 6), TextSelections.extractOne("/*(*/ /*)*/)")) 28 | } 29 | 30 | @Test 31 | def testWithOneSetCursorFromRightSelection(): Unit = { 32 | assertEquals(Range(5, 6), TextSelections.extractOne("012345/*<-cursor*/")) 33 | assertEquals(Range(5, 6), TextSelections.extractOne("012345/*<-cursor-0*/")) 34 | assertEquals(Range(4, 5), TextSelections.extractOne("012345/*<-cursor-1*/")) 35 | assertEquals(Range(3, 4), TextSelections.extractOne("012345/*<-cursor-2*/")) 36 | } 37 | 38 | @Test(expected = classOf[IllegalArgumentException]) 39 | def testWithCursorInvalidCursorFromRightSelection(): Unit = { 40 | TextSelections.extractOne("0123/*<-cursor-18*/") 41 | } 42 | 43 | @Test 44 | def testWithOneSetCursorFromLeftSelection(): Unit = { 45 | assertEquals(Range(14, 15), TextSelections.extractOne(" /*cursor->*/45678")) 46 | assertEquals(Range(14, 15), TextSelections.extractOne("/*0-cursor->*/45678")) 47 | assertEquals(Range(15, 16), TextSelections.extractOne("/*1-cursor->*/45678")) 48 | assertEquals(Range(16, 17), TextSelections.extractOne("/*2-cursor->*/45678")) 49 | } 50 | 51 | @Test(expected = classOf[IllegalArgumentException]) 52 | def testWithCursorInvalidCursorFromLeftSelection(): Unit = { 53 | TextSelections.extractOne("/*cursor->*/") 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/util/UnionFindInitTest.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | package tests.util 3 | 4 | import scala.util.Random 5 | 
import scala.tools.refactoring.util.UnionFind 6 | import org.junit.Test 7 | import org.junit.Assert._ 8 | 9 | class UnionFindInitTest { 10 | 11 | // We test this on 100 randomly colored Nodes 12 | val colorString = Array("Red", "Blue", "Green", "Yellow", "Blue") 13 | class Node(val color: Int){override def toString() = this.hashCode().toString() + "( " + colorString(color) + ")"} 14 | val testNodes: List[Node] = List.fill(100){new Node(Random.nextInt(5))} 15 | val uf = new UnionFind[this.Node]() 16 | 17 | @Test 18 | def firstInsertedNodesShouldBeTheirOwnParents() = { 19 | for (node <- testNodes) uf.find(node) 20 | // inserted Nodes are their Parents 21 | assertTrue(testNodes.forall{(x) => uf.find(x) == x}) 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/test/scala/scala/tools/refactoring/tests/util/UnionFindTest.scala: -------------------------------------------------------------------------------- 1 | package scala.tools.refactoring 2 | package tests.util 3 | 4 | import scala.util.Random 5 | import scala.tools.refactoring.util.UnionFind 6 | import org.junit.Test 7 | import org.junit.Assert._ 8 | import org.junit.Before 9 | 10 | class UnionFindTest { 11 | 12 | // We test this on 100 randomly colored Nodes 13 | val colorString = Array("Red", "Blue", "Green", "Yellow", "Blue") 14 | class Node(val color: Int){override def toString() = this.hashCode().toString() + "( " + colorString(color) + ")"} 15 | val testNodes: List[Node] = List.fill(100){new Node(Random.nextInt(5))} 16 | val uf = new UnionFind[this.Node]() 17 | 18 | @Before 19 | def unknownNodesShouldNotThrowWhenUnited() = { 20 | for (node1 <- testNodes; 21 | node2 <- testNodes if node1.color == node2.color) uf.union(node1, node2) 22 | } 23 | 24 | @Test 25 | def atMostFiveRepsInUF() = { 26 | val nodesFromUF = testNodes.map(uf.find(_)).distinct 27 | val repsLength = nodesFromUF.length 28 | assertTrue(s"Expected five reps in the union-find, found $repsLength 
!", repsLength <= 5) 29 | } 30 | 31 | @Test 32 | def nodesInRelationIfAndOnlyIfWithSameColor(): Unit ={ 33 | def sameColorImpliesRelation(x: Node, y: Node) = x.color != y.color || uf.find(x) == uf.find(y) 34 | def relationImpliesSameColor(x: Node, y: Node) = uf.find(x) != uf.find(y) || x.color == y.color 35 | for (x <- testNodes; 36 | y <- testNodes) { 37 | val px = uf.find(x) 38 | val py = uf.find(y) 39 | assertTrue(s"problem found with node $x (parent $px) and $y (parent $py)", sameColorImpliesRelation(x, y) && relationImpliesSameColor(x,y)) 40 | } 41 | } 42 | 43 | def colorRepresentant(): Array[Node] = { 44 | val representants = new Array[Node](5) 45 | for (c <- 0 to 4){ 46 | representants(c) = testNodes.collectFirst{ case (x: Node) if (x.color == c) => uf.find(x)}.getOrElse(uf.find(new Node(c))) 47 | } 48 | representants 49 | } 50 | 51 | @Test 52 | def classRepresentantIsUnique(): Unit ={ 53 | val reps = colorRepresentant() 54 | testNodes.foreach{(x) => { 55 | val xColor = colorString(x.color) 56 | val xColorRep = reps(x.color) 57 | assertTrue(s"problem found with $x yet the representant of $xColor is $xColorRep", uf.find(x) == xColorRep)} 58 | } 59 | } 60 | 61 | @Test 62 | def findIsIdempotent(): Unit ={ 63 | assertTrue(testNodes.forall{(x) => val p = uf.find(x); p == uf.find(p)}) 64 | } 65 | 66 | @Test 67 | def nodesForWhichFindIsIdentityAreReps(): Unit ={ 68 | val selfRepresented = testNodes.filter{ (n)=> uf.find(n) == n } 69 | val reps = colorRepresentant() 70 | assertTrue(reps.forall{(x) => selfRepresented.contains(x)}) 71 | assertTrue(selfRepresented.forall{(x) => reps.contains(x)}) 72 | } 73 | 74 | @Test 75 | def equivalenceClassGivesAColor(): Unit ={ 76 | def myClassIsExactlyMyColor(n: Node): Boolean = { 77 | val myClass = uf.equivalenceClass(n) 78 | val inClassImpliesSameColor = testNodes.forall{ (x) => !myClass.contains(x) || x.color == n.color} 79 | val sameColorImpliesInclass = testNodes.forall{ (x) => x.color != n.color || myClass.contains(x)} 80 | 
inClassImpliesSameColor && sameColorImpliesInclass 81 | } 82 | // A bit overkill to do this on more than representants 83 | assertTrue(colorRepresentant().forall{myClassIsExactlyMyColor}) 84 | } 85 | 86 | } 87 | --------------------------------------------------------------------------------